Understand TensorFlow tf.while_loop(): Repeat a Function – TensorFlow Tutorial

By | June 17, 2020

TensorFlow tf.while_loop() allows us to repeat a function. In this tutorial, we will introduce how to use this function with some examples.



tf.while_loop() repeats body while the condition cond is true.

To use this function correctly, we should pay attention to these parameters:

cond: a callable that returns a boolean scalar tensor; the loop keeps running while it is True.

body: the callable that represents the loop body.

loop_vars: a tuple or list holding the initial values of the loop variables; they are passed to both cond and body.

back_prop: whether backprop is enabled for this while loop.

How to understand tf.while_loop()?

tf.while_loop() can be regarded as:

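def while_loop(cond, body, loop_vars):
    state = loop_vars              # the initial loop variables
    while cond(*state):            # cond receives every loop variable
        state = body(*state)       # body returns the new loop variables
    return state                   # final values of the loop variables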

From this code, we can see:

1. cond is a callable, such as a lambda or a function.

2. body is also a callable; it returns the new state, which must contain the same number of values as loop_vars.

3. The values in loop_vars are passed into cond, which is executed and returns True or False.
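To make points 1 to 3 concrete, here is a minimal plain-Python model. It is an assumption for illustration, not TensorFlow's implementation: `while_loop_model` is a hypothetical name, it unpacks the loop variables into cond and body, and, like tf.while_loop, it rejects a body that returns a different number of values than loop_vars.

```python
# Plain-Python model of the tf.while_loop calling convention.
# Assumption: this only imitates how cond, body and loop_vars interact;
# the real tf.while_loop builds a TensorFlow op instead of a Python loop.


def while_loop_model(cond, body, loop_vars):
    """Call body(*state) repeatedly while cond(*state) is True."""
    state = tuple(loop_vars)
    while cond(*state):
        new_state = tuple(body(*state))
        # body must return as many values as there are loop variables
        if len(new_state) != len(state):
            raise ValueError(
                f"body must return {len(state)} values, got {len(new_state)}")
        state = new_state
    return state


# (i, total) are the loop variables; the loop adds 0 + 1 + 2 + 3 + 4.
i, total = while_loop_model(
    cond=lambda i, total: i < 5,
    body=lambda i, total: (i + 1, total + i),
    loop_vars=(0, 0),
)
print(i, total)  # 5 10

# A body that drops a loop variable is rejected on the first iteration.
try:
    while_loop_model(lambda i, total: i < 5, lambda i, total: (i + 1,), (0, 0))
except ValueError as err:
    message = str(err)
    print(message)  # body must return 2 values, got 1
```

The loop variables are a tuple here so that the same unpacking works for any number of them, which mirrors how cond and body each take one argument per loop variable.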

We will use some examples to illustrate how to use tf.while_loop() function.

Cumulative sum of a tensor

We will calculate the cumulative sum of a tensor using tf.while_loop() function.

import tensorflow as tf
import numpy as np

# the shape of x1 is (3, 4)
x1 = tf.Variable(np.array([[2, 2, 3, 4], [1, 5, 3, 2] ,[6, 7, 2, 1]], dtype = np.float32), name = 'x1')

Get the maximum loop count, which is the number of rows in x1:

max_loop = tf.shape(x1)[0]

Then create a body function that adds one row of x1 to the running sum:

def cumulativeSum(i, pre_ele):
    next_ele = tf.nn.embedding_lookup(x1, i) # get row i of x1
    result = pre_ele + next_ele # add it to the running sum
    i_next = i + 1 # move to the next row
    return i_next, result

Here i controls the loop and pre_ele holds the intermediate (running) result.

init_result = tf.Variable(np.array([0, 0, 0, 0], dtype = np.float32))
_, result = tf.while_loop(cond = lambda i, result: tf.less(i, max_loop),
                          body = cumulativeSum,
                          loop_vars = (
                                       tf.constant(0, dtype=tf.int32),
                                       init_result))

The example code can be regarded as:

i = 0
result = tf.Variable(np.array([0, 0, 0, 0], dtype = np.float32))
init_state = (i, result)  # corresponds to loop_vars

max_loop = 3  # x1 has 3 rows

while(i < max_loop):
    i, result = cumulativeSum(i, result)  # add row i of x1 to the running sum
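The equivalent loop can also be traced by hand without TensorFlow. The sketch below is an assumption for illustration: it replaces tensors with Python lists, copies the values of x1 from the example, and uses a hypothetical helper `cumulative_sum` in place of cumulativeSum, printing the running sum after each iteration.

```python
# Plain-Python version of the cumulative-sum loop (lists instead of tensors).
# x1 copies the values from the TensorFlow example; cumulative_sum is a
# hypothetical stand-in for cumulativeSum, used only to trace the loop.

x1 = [[2, 2, 3, 4],
      [1, 5, 3, 2],
      [6, 7, 2, 1]]


def cumulative_sum(i, pre_ele):
    """Same role as cumulativeSum: add row i of x1 to the running sum."""
    next_ele = x1[i]
    result = [a + b for a, b in zip(pre_ele, next_ele)]
    return i + 1, result


i, result = 0, [0, 0, 0, 0]   # the initial loop variables
max_loop = len(x1)            # 3 rows, like tf.shape(x1)[0]
while i < max_loop:
    i, result = cumulative_sum(i, result)
    print(i, result)
# 1 [2, 2, 3, 4]
# 2 [3, 7, 6, 6]
# 3 [9, 14, 8, 7]
```

The final value of result matches the output of the TensorFlow version, because both loops add the same three rows in the same order.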

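We can output the result in a TensorFlow 1.x session as below:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize x1 and init_result
    print(sess.run(result))

Running this code, we get the cumulative sum of x1: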

[ 9. 14.  8.  7.]
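The printed vector is the sum of all three rows of x1, i.e. the running sum after the last iteration. A quick column-by-column check:

```latex
\begin{aligned}
\text{result} &= x_1[0] + x_1[1] + x_1[2] \\
&= [2, 2, 3, 4] + [1, 5, 3, 2] + [6, 7, 2, 1] \\
&= [2+1+6,\ 2+5+7,\ 3+3+2,\ 4+2+1] = [9, 14, 8, 7]
\end{aligned}
```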
