In this tutorial, we will introduce how to run TensorFlow functions conditionally with tf.cond(). This is very useful when building TensorFlow models.
For example, we may need logic like this:

```python
flag = True
if flag:
    run_fun_1()
else:
    run_fun_2()
```

Whether run_fun_1 or run_fun_2 is called is determined by flag.
In TensorFlow, we can use tf.cond() to implement the same logic.
Syntax
```python
tf.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None)
```
Return true_fn() if the predicate pred is true else false_fn()
It means:
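- If pred is True, tf.cond() runs true_fn and returns its value.
- If pred is False, tf.cond() runs false_fn and returns its value.

Note that:
- pred must be a scalar boolean tensor whose value is True or False.
- true_fn and false_fn must be callables, such as Python functions or lambdas that take no arguments and return the tensors to use as the result.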
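To make these rules concrete, here is a plain-Python model of how tf.cond() treats its arguments. It is only a sketch under stated assumptions: cond_model is a hypothetical helper, ordinary Python lists stand in for tensors, and no graph is built. One real difference to keep in mind: in TensorFlow 1.x graph mode, tf.cond() calls both true_fn and false_fn once while building the graph, and only the ops of the selected branch are executed when the session runs.

```python
# A plain-Python model of tf.cond's selection rule; it is NOT TensorFlow.
# cond_model is a hypothetical helper, and lists stand in for tensors.

def cond_model(pred, true_fn, false_fn):
    """Return true_fn() if pred is true, else false_fn()."""
    # Like tf.cond, insist that both branches are callables, not values
    if not callable(true_fn):
        raise TypeError("true_fn must be callable.")
    if not callable(false_fn):
        raise TypeError("false_fn must be callable.")
    # In this model, only the selected branch is ever called
    return true_fn() if pred else false_fn()


t1 = [[200.0, 4.0, 5.0], [20.0, 5.0, 70.0]]
t2 = [1.0, 2.0]

print(cond_model(True, lambda: t1, lambda: t2))   # [[200.0, 4.0, 5.0], [20.0, 5.0, 70.0]]
print(cond_model(False, lambda: t1, lambda: t2))  # [1.0, 2.0]

try:
    cond_model(True, t1, lambda: t2)  # t1 is a value, not a function
except TypeError as err:
    print(err)  # true_fn must be callable.
```

Wrapping each branch in a lambda is what turns a value into a zero-argument callable, which is exactly what the real tf.cond() expects.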
For example:
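```python
import tensorflow as tf
import numpy as np

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]], [[2, 3, 5], [5, 5, 7]]]), dtype=tf.float32, name='labels')
t2 = tf.Variable(np.array([1, 2]), dtype=tf.float32, name='predictions')

# pred is a boolean tensor whose value is False, so the false branch is selected
result = tf.cond(tf.convert_to_tensor(False), lambda: t1, lambda: t2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))  # the value of t2
```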
Because pred is False, result holds the value of t2.
To learn how to use Python lambda functions, you can refer to:
Understand Python Lambda Function for Beginners – Python Tutorial
Can pred be a Python bool?
Look at the code below:

```python
result = tf.cond(True, lambda: t1, lambda: t2)
```

Here pred is the Python boolean True. Will t1 be returned?
If you run this code, you will get an error:
TypeError: pred must not be a Python bool
This means pred cannot be a Python bool; we should convert it to a tensor first, for example with tf.convert_to_tensor(True).
The rank of pred
pred must be a scalar tensor, that is, its rank must be 0.
Look at this code:
```python
t3 = tf.Variable(np.array([[200, 4, 5], [20, 5, 70]]), dtype=tf.float32)
t4 = tf.Variable(np.array([[1, 4, 5], [1, 5, 70]]), dtype=tf.float32)

# t4 > t3 is an element-wise comparison with shape [2, 3], not a scalar
result = tf.cond(t4 > t3, lambda: t1, lambda: t2)
```
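Because t4 > t3 compares the two tensors element by element, it returns a boolean tensor of shape [2, 3] (rank 2), so running this code raises:
ValueError: Shape must be rank 0 but is rank 2 for 'cond/Switch' (op: 'Switch') with input shapes: [2,3], [2,3].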
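One common fix is to collapse the element-wise comparison into a single boolean before passing it to tf.cond(), for example tf.cond(tf.reduce_all(t4 > t3), lambda: t1, lambda: t2), which is true only if every element of t4 is greater, or tf.reduce_any(), which is true if at least one element is. The sketch below reproduces that reduction in plain Python on the same values of t3 and t4, so you can see which scalar predicate you would get. It models the arithmetic only; all() and any() are stand-ins for the TensorFlow reductions, not TensorFlow calls.

```python
# Plain-Python model of reducing an element-wise comparison to one boolean.
# The nested lists hold the same values as t3 and t4 above.
t3 = [[200, 4, 5], [20, 5, 70]]
t4 = [[1, 4, 5], [1, 5, 70]]

# Element-wise t4 > t3: a 2x3 grid of booleans (rank 2, as in the error)
greater = [[a > b for a, b in zip(row4, row3)] for row4, row3 in zip(t4, t3)]

# Analogues of tf.reduce_all / tf.reduce_any: collapse the grid to one bool
all_greater = all(all(row) for row in greater)
any_greater = any(any(row) for row in greater)

print(greater)      # [[False, False, False], [False, False, False]]
print(all_greater)  # False
print(any_greater)  # False
```

Since no element of t4 is greater than the matching element of t3, both reductions give False, so either version of the fixed tf.cond() call would select t2.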
How to use tf.cond()?
We will use an example to show how to do it.
Look at this example:
```python
import tensorflow as tf
import numpy as np

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]], [[2, 3, 5], [5, 5, 7]]]), dtype=tf.float32, name='labels')

def addT(a, b):
    return a + 2 * b

def subT(a, b):
    return a - 3 * b

t2 = tf.Variable(np.array([1, 2]), dtype=tf.float32, name='predictions')

# Note: addT(t1, t1) and subT(t2, t2) are evaluated before tf.cond is called
result = tf.cond(tf.less(3, 4), addT(t1, t1), subT(t2, t2))

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(result))
```
Can addT(t1, t1) be passed to tf.cond() directly? If you run this code, you will get this error:
TypeError: true_fn must be callable.
You should wrap each call in a lambda, so that tf.cond() receives functions instead of tensors:

```python
result = tf.cond(tf.less(3, 4), lambda: addT(t1, t1), lambda: subT(t2, t2))
```
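The error happens because Python evaluates addT(t1, t1) before tf.cond() is even called, so tf.cond() receives the resulting tensor rather than a function. A lambda delays the call; functools.partial also produces a callable and should work the same way, though that is an alternative, not what this tutorial uses. The plain-Python sketch below shows the difference with a call counter. add_t mirrors addT but works on numbers, and calls is a hypothetical bookkeeping list added for illustration.

```python
import functools

calls = []  # records each time add_t actually runs

def add_t(a, b):
    calls.append((a, b))
    return a + 2 * b

# Passing add_t(1, 1) evaluates it immediately: we get a number, not a function
eager = add_t(1, 1)
print(callable(eager), len(calls))  # False 1

# Wrapping the call defers it until someone invokes the wrapper
deferred = lambda: add_t(1, 1)
partial_fn = functools.partial(add_t, 1, 1)
print(callable(deferred), callable(partial_fn), len(calls))  # True True 1

# Each wrapper runs add_t only when it is called
print(deferred(), partial_fn(), len(calls))  # 3 3 3
```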
Run this code, and the result is:

```
[[[600.  12.  15.]
  [ 60.  15. 210.]]

 [[  6.   9.  15.]
  [ 15.  15.  21.]]]
```
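Since both arguments are t1 and tf.less(3, 4) is True, the result is addT(t1, t1) = t1 + 2*t1 = 3*t1, which is why every value above is three times the matching entry of t1. The plain-Python sketch below recomputes this with nested lists standing in for tensors; add_t, sub_t and elementwise are hypothetical helpers for illustration only. It also shows what the false branch, subT(t2, t2) = t2 - 3*t2 = -2*t2, would have produced.

```python
# Plain-Python check of the printed result (no TensorFlow involved).
t1 = [[[200, 4, 5], [20, 5, 70]], [[2, 3, 5], [5, 5, 7]]]
t2 = [1, 2]

def elementwise(f, a, b):
    """Apply f to matching scalars of two equally nested lists."""
    if isinstance(a, list):
        return [elementwise(f, x, y) for x, y in zip(a, b)]
    return f(a, b)

def add_t(a, b):  # mirrors addT: a + 2*b, as floats like tf.float32
    return elementwise(lambda x, y: float(x + 2 * y), a, b)

def sub_t(a, b):  # mirrors subT: a - 3*b
    return elementwise(lambda x, y: float(x - 3 * y), a, b)

pred = 3 < 4  # the value tf.less(3, 4) evaluates to: True
result = add_t(t1, t1) if pred else sub_t(t2, t2)

print(result)
# [[[600.0, 12.0, 15.0], [60.0, 15.0, 210.0]], [[6.0, 9.0, 15.0], [15.0, 15.0, 21.0]]]
print(sub_t(t2, t2))  # [-2.0, -4.0], the false branch's value
```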