TensorFlow's tf.while_loop() lets us run a function (the loop body) repeatedly as long as a condition holds. In this tutorial, we will introduce how to use it with an example.
```python
tf.while_loop(cond, body, loop_vars, shape_invariants=None,
              parallel_iterations=10, back_prop=True, swap_memory=False,
              name=None, maximum_iterations=None)
```
Repeat body while the condition cond is true.
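To use this function correctly, pay attention to the following parameters.
cond: a callable that receives the loop variables and returns a boolean scalar (True or False); the loop keeps running while it is True.
body: a callable that represents the loop body; it receives the loop variables and returns their updated values.
loop_vars: a tuple or list of tensors holding the initial values of the loop variables, which are passed to cond and body.
back_prop: whether backpropagation is enabled for this while loop.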
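How to understand tf.while_loop()?
Conceptually, tf.while_loop() behaves like this Python function:
```python
def while_loop(cond, body, loop_vars):
    state = loop_vars
    while cond(*state):       # the loop variables are passed as separate arguments
        state = body(*state)  # body returns the updated loop variables
    return state
```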
From this code, we can see that:
1. cond is a callable, such as a lambda or a function.
2. body is also a callable; it returns the new state, which must contain the same number of values as loop_vars.
3. loop_vars is the initial state; it is passed to cond, which returns True or False to decide whether body runs again.
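The pseudocode above can be turned into a small runnable model. This is a plain-Python sketch of the semantics only, not TensorFlow's implementation: the function name while_loop here is our own, it works on ordinary Python values, and it ignores shape_invariants, parallel_iterations and the other options. It does show that cond and body receive the loop variables unpacked as separate arguments, which is why the cond in the example below is written as lambda i, result: ....

```python
def while_loop(cond, body, loop_vars):
    """Plain-Python model of tf.while_loop semantics (not TensorFlow's code)."""
    state = tuple(loop_vars)
    while cond(*state):               # cond gets the loop variables unpacked
        state = tuple(body(*state))   # body returns the new loop variables
        if len(state) != len(loop_vars):
            # body must keep the same number of loop variables
            raise ValueError("body must return as many values as loop_vars")
    return state


# Count i from 0 to 5 while summing it: total = 0 + 1 + 2 + 3 + 4
i_final, total = while_loop(
    cond=lambda i, total: i < 5,
    body=lambda i, total: (i + 1, total + i),
    loop_vars=(0, 0),
)
print(i_final, total)  # 5 10
```

Returning a tuple of the same length as loop_vars from body is what keeps the state shape stable from one iteration to the next; TensorFlow likewise expects body to preserve the structure of loop_vars.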
We will use an example to illustrate how to use tf.while_loop().
Cumulative sum of a tensor
We will use tf.while_loop() to add up the rows of a 2-D tensor, that is, to compute its cumulative sum along the first axis.
```python
import tensorflow as tf
import numpy as np

# the shape of x1 is (3, 4)
x1 = tf.Variable(np.array([[2, 2, 3, 4],
                           [1, 5, 3, 2],
                           [6, 7, 2, 1]], dtype=np.float32), name='x1')
```
Get the maximum loop count, which is the number of rows in x1.
```python
max_loop = tf.shape(x1)[0]
```
Then create a loop body that adds the next row of x1 to the running sum:
```python
def cumulativeSum(i, pre_ele):
    next_ele = tf.nn.embedding_lookup(x1, i)  # get row i of x1
    result = pre_ele + next_ele
    i_next = i + 1
    return i_next, result
```
Here, i is the loop counter and pre_ele holds the intermediate (running) sum.
```python
init_result = tf.Variable(np.array([0, 0, 0, 0], dtype=np.float32))
_, result = tf.while_loop(cond=lambda i, result: tf.less(i, max_loop),
                          body=cumulativeSum,
                          loop_vars=(tf.constant(0, dtype=tf.int32), init_result))
```
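The tf.while_loop() call above can be regarded as the following plain loop:
```python
i = 0
result = np.array([0, 0, 0, 0], dtype=np.float32)  # initial value of init_result
max_loop = 3  # number of rows in x1
while i < max_loop:
    i, result = cumulativeSum(i, result)  # pass the running sum, not x1[i]
```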
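To see how the loop variables (i, result) evolve, the same loop can be traced in plain Python. This is an illustration only: lists stand in for tensors, and cumulative_sum_list is a hypothetical list-based counterpart of cumulativeSum, not part of the original code.

```python
# Rows of x1 from the example, as plain Python lists
x1 = [[2, 2, 3, 4],
      [1, 5, 3, 2],
      [6, 7, 2, 1]]


def cumulative_sum_list(i, pre_ele):
    # Hypothetical list version of cumulativeSum: add row i element-wise
    next_ele = x1[i]
    result = [a + b for a, b in zip(pre_ele, next_ele)]
    return i + 1, result


state = (0, [0, 0, 0, 0])  # initial loop_vars: (i, result)
trace = [state]
while state[0] < len(x1):  # cond: i < max_loop
    state = cumulative_sum_list(*state)
    trace.append(state)

for i, result in trace:
    print(i, result)
# 0 [0, 0, 0, 0]
# 1 [2, 2, 3, 4]
# 2 [3, 7, 6, 6]
# 3 [9, 14, 8, 7]
```

The loop stops once i reaches 3, and the second loop variable then holds the sum of all three rows.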
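To print the result of the TensorFlow loop, run it in a session:
```python
# TensorFlow 1.x API; in TF 2.x use tf.compat.v1 with eager execution disabled
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize x1 and init_result
    print(sess.run(result))
```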
Running this code prints the cumulative sum of the rows of x1:
```
[ 9. 14.  8.  7.]
```
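The code above relies on the TensorFlow 1.x API (tf.Session, tf.global_variables_initializer). Under TensorFlow 2.x, where eager execution is the default, a version like the one below should compute the same values. This is a sketch assuming TF 2.x; using tf.gather instead of tf.nn.embedding_lookup, and tf.constant/tf.zeros instead of tf.Variable, are simplifications chosen here rather than part of the original tutorial.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

# Same 3x4 input as above; a constant suffices because the loop only reads it
x1 = tf.constant([[2, 2, 3, 4],
                  [1, 5, 3, 2],
                  [6, 7, 2, 1]], dtype=tf.float32)
max_loop = tf.shape(x1)[0]  # number of rows (3)


def cumulativeSum(i, pre_ele):
    # tf.gather with a scalar index returns row i of x1, shape (4,)
    next_ele = tf.gather(x1, i)
    return i + 1, pre_ele + next_ele


_, result = tf.while_loop(
    cond=lambda i, acc: tf.less(i, max_loop),
    body=cumulativeSum,
    loop_vars=(tf.constant(0, dtype=tf.int32), tf.zeros([4], dtype=tf.float32)),
)
print(result.numpy())
```

No session or variable initializer is needed here: in eager mode the loop runs step by step and result is an ordinary tensor whose value can be read with .numpy().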