Understand TensorFlow tf.while_loop(): Repeat a Function – TensorFlow Tutorial

By | June 17, 2020

TensorFlow tf.while_loop() lets us run a function (the loop body) repeatedly while a condition holds. In this tutorial, we will introduce how to use it with an example.

Syntax

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    name=None,
    maximum_iterations=None
)

Repeat body while the condition cond is true.

To use this function correctly, pay attention to the following parameters.

cond: a callable that represents the loop condition; it returns True or False (a boolean scalar tensor).

body: the callable that represents the loop body.

loop_vars: a tuple or list holding the initial values of the loop variables; these values are passed to cond and body.

back_prop: whether backprop is enabled for this while loop.
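The syntax also lists maximum_iterations. According to the TensorFlow API documentation, when it is set, cond is effectively AND-ed with an iteration counter, so the loop stops after that many iterations even if cond is still true. Here is a minimal pure-Python sketch of that stopping rule, with no TensorFlow involved; `sketch_while_loop` is a hypothetical helper, not part of the TensorFlow API:

```python
def sketch_while_loop(cond, body, loop_vars, maximum_iterations=None):
    """Hypothetical pure-Python model of tf.while_loop's stopping rule.

    loop_vars is a tuple; its items are passed to cond and body as separate
    positional arguments, and body returns a tuple of the same length.
    """
    state = tuple(loop_vars)
    iterations = 0
    # Keep looping only while the iteration budget remains AND cond is true.
    while (maximum_iterations is None or iterations < maximum_iterations) and cond(*state):
        state = tuple(body(*state))
        iterations += 1
    return state


# Count upward while i < 10, without and with an iteration limit.
print(sketch_while_loop(lambda i: i < 10, lambda i: (i + 1,), (0,)))                        # (10,)
print(sketch_while_loop(lambda i: i < 10, lambda i: (i + 1,), (0,), maximum_iterations=3))  # (3,)
```

When cond turns false first, the limit has no effect; the limit only matters for loops that would otherwise run longer.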

How to understand tf.while_loop()?

Conceptually, tf.while_loop() behaves like this Python function:

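def while_loop(cond, body, loop_vars):
    state = loop_vars
    while cond(*state):        # each loop variable is passed to cond as a separate argument
        state = body(*state)   # body returns new values for all loop variables
    return state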

From this code, we can see that:

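1. cond is a callable, such as a lambda or a function.

2. body is also a callable; it returns the new state, which must contain the same number of values (with the same structure) as loop_vars.

3. The loop variables are passed into cond as arguments, and cond returns True or False to decide whether body runs again.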
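The three points above can be checked with a small pure-Python model. It is a sketch, not TensorFlow's implementation: `model_while_loop` is a hypothetical name, and its length check only imitates TensorFlow's requirement that body return the same structure as loop_vars. The example computes 5! with two loop variables, a counter i and an accumulator acc:

```python
def model_while_loop(cond, body, loop_vars):
    """Hypothetical pure-Python model of tf.while_loop (no TensorFlow involved)."""
    state = tuple(loop_vars)
    while cond(*state):                  # point 3: the loop variables go into cond
        new_state = tuple(body(*state))  # point 2: body returns the next state
        if len(new_state) != len(state):
            raise ValueError("body must return as many values as there are loop_vars")
        state = new_state
    return state


# Two loop variables: a counter i and a running product acc (computes 5!).
cond = lambda i, acc: i <= 5             # point 1: any callable, here a lambda
body = lambda i, acc: (i + 1, acc * i)   # two values out, matching two values in

i, factorial = model_while_loop(cond, body, (1, 1))
print(i, factorial)  # 6 120
```

A body that returned only `(i + 1,)` here would raise the ValueError, which mirrors the kind of structure mismatch TensorFlow reports.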

We will use an example to illustrate how to use tf.while_loop().

Cumulative sum of a tensor

We will use tf.while_loop() to calculate the cumulative sum of the rows of a tensor.

import tensorflow as tf
import numpy as np

# the shape of x1 is (3, 4)
x1 = tf.Variable(np.array([[2, 2, 3, 4], [1, 5, 3, 2] ,[6, 7, 2, 1]], dtype = np.float32), name = 'x1')

Get the maximum loop count, which is the number of rows in x1 (3 here).

max_loop = tf.shape(x1)[0]

Then create the loop body, a function that adds the next row of x1 to the running sum:

def cumulativeSum(i, pre_ele):
    next_ele = tf.nn.embedding_lookup(x1, i)  # get row i of x1
    result = pre_ele + next_ele
    i_next = i + 1
    return i_next, result

Here i controls the loop (it is the index of the next row to add), and pre_ele holds the running sum computed so far.
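To see what each call of cumulativeSum does, here is a pure-Python sketch of the same body that uses plain lists of integers instead of float tensors. The names `x1_rows` and `cumulative_sum_step` are hypothetical stand-ins: `x1_rows` holds the same values as x1, and `x1_rows[i]` plays the role of `tf.nn.embedding_lookup(x1, i)`, which for a scalar i returns row i. Each call adds row i to the running result and advances i:

```python
# Plain-list stand-in for x1 (same values as the tf.Variable, as ints).
x1_rows = [[2, 2, 3, 4], [1, 5, 3, 2], [6, 7, 2, 1]]

def cumulative_sum_step(i, pre_ele):
    """Hypothetical pure-Python version of the loop body."""
    next_ele = x1_rows[i]                                # row i of x1
    result = [a + b for a, b in zip(pre_ele, next_ele)]  # element-wise add
    return i + 1, result                                 # advance i, keep the new sum

state = (0, [0, 0, 0, 0])
for _ in range(len(x1_rows)):
    state = cumulative_sum_step(*state)
    print(state)
# (1, [2, 2, 3, 4])
# (2, [3, 7, 6, 6])
# (3, [9, 14, 8, 7])
```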

init_result = tf.Variable(np.array([0, 0, 0, 0], dtype = np.float32))
_, result = tf.while_loop(cond = lambda i, result: tf.less(i, max_loop),
                          body = cumulativeSum,
                          loop_vars = (
                                       tf.constant(0, dtype=tf.int32),
                                       init_result
                                       )
                          )
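The cond lambda above takes two arguments, i and result, even though it only uses i: tf.while_loop passes every loop variable to cond, so cond must accept as many arguments as there are loop_vars. The pure-Python sketch below imitates that call convention with `cond(*loop_vars)` (plain values stand in for the tensors; `good_cond` and `bad_cond` are illustrative names) to show why a one-argument lambda would fail:

```python
loop_vars = (0, [0.0, 0.0, 0.0, 0.0])  # stand-ins for (tf.constant(0), init_result)
max_loop = 3

good_cond = lambda i, result: i < max_loop  # accepts both loop variables
bad_cond = lambda i: i < max_loop           # accepts only one

print(good_cond(*loop_vars))  # True

try:
    bad_cond(*loop_vars)  # two arguments passed to a one-argument lambda
except TypeError as err:
    print("TypeError:", err)
```

Naming the unused argument `_`, as in `lambda i, _: tf.less(i, max_loop)`, is a common Python way to signal that it is ignored.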

The tf.while_loop() example is roughly equivalent to the following Python loop:

i = 0
result = np.array([0, 0, 0, 0], dtype = np.float32)  # plays the role of init_result
max_loop = 3  # number of rows in x1

while i < max_loop:
    i, result = cumulativeSum(i, result)  # add row i of x1 to result, then i += 1

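We can print the result in a session. Note that tf.Session and tf.global_variables_initializer() belong to the TensorFlow 1.x API; in TensorFlow 2.x they are available as tf.compat.v1.Session and tf.compat.v1.global_variables_initializer, with eager execution disabled.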

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))

Running this code prints the cumulative sum of the rows of x1, i.e. the column-wise total:

[ 9. 14.  8.  7.]
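Because the loop adds every row of x1 into result, the final value is the column-wise total, which is what `tf.reduce_sum(x1, axis=0)` computes in one call; if you want every intermediate sum instead, `tf.cumsum(x1, axis=0)` returns them. A quick pure-Python cross-check of the printed output, with `x1_rows` as a hypothetical plain-list stand-in for x1:

```python
x1_rows = [[2, 2, 3, 4], [1, 5, 3, 2], [6, 7, 2, 1]]

# zip(*rows) groups the values column by column; sum each column.
column_sum = [sum(col) for col in zip(*x1_rows)]
print(column_sum)  # [9, 14, 8, 7]
```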
