Tutorial Example

Understand tf.cond(): Run TensorFlow Function By True Condition – TensorFlow Tutorial

In this tutorial, we will introduce how to run a TensorFlow function based on a condition. This is very useful when building TensorFlow models.

For example, we may need logic like this:

def run_fun_1():
    print("run_fun_1")
def run_fun_2():
    print("run_fun_2")

flag = True
if flag:
    run_fun_1()
else:
    run_fun_2()

Whether run_fun_1() or run_fun_2() is called is determined by flag.

In TensorFlow, we can use tf.cond() to implement the same logic.
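Here is a minimal sketch of the flag logic written with tf.cond(). It assumes TensorFlow 1.x style graph execution and uses tf.compat.v1 so that it also runs under TensorFlow 2.x. run_fun_1 and run_fun_2 are hypothetical stand-ins that just return constants, and flag is a scalar bool placeholder whose value is fed at run time.

```python
import tensorflow.compat.v1 as tf

# Use TF 1.x graph mode (needed when running under TensorFlow 2.x).
tf.disable_v2_behavior()

# Hypothetical branch functions; each builds and returns a tensor.
def run_fun_1():
    return tf.constant(1.0)

def run_fun_2():
    return tf.constant(2.0)

# flag is a rank-0 bool tensor whose value is supplied at run time.
flag = tf.placeholder(tf.bool, shape=[], name="flag")

# Pass the functions themselves (no parentheses); tf.cond() calls them.
result = tf.cond(flag, run_fun_1, run_fun_2)

with tf.Session() as sess:
    out_true = sess.run(result, feed_dict={flag: True})    # value from run_fun_1
    out_false = sess.run(result, feed_dict={flag: False})  # value from run_fun_2
    print(out_true, out_false)
```

The same graph returns different values depending only on what is fed for flag, which is the tensor equivalent of the Python if/else.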

Syntax

tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)

Return true_fn() if the predicate pred is true else false_fn()

It means:

If pred is True, tf.cond() runs true_fn and returns its value.

If pred is False, tf.cond() runs false_fn and returns its value.

We should notice:

For example:

import tensorflow as tf
import numpy as np

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]],[[2, 3, 5], [5, 5, 7]]]), dtype = tf.float32, name = 'labels')

t2 = tf.Variable(np.array([1,2]), dtype = tf.float32, name = 'predictions')
result = tf.cond(tf.convert_to_tensor(False), lambda: t1, lambda: t2)

In this example, we will get t2.
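tf.cond() only builds a graph node; to see the returned value you have to run it in a session. A minimal sketch, assuming TensorFlow 1.x graph mode through tf.compat.v1, that initializes the two variables and evaluates result. Note that the two branches return tensors of different shapes here ([2, 2, 3] and [2]), and the graph still builds.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph mode

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]], [[2, 3, 5], [5, 5, 7]]]),
                 dtype=tf.float32, name='labels')
t2 = tf.Variable(np.array([1, 2]), dtype=tf.float32, name='predictions')

# pred is a rank-0 bool tensor equal to False, so false_fn runs and t2 is returned.
result = tf.cond(tf.convert_to_tensor(False), lambda: t1, lambda: t2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(result)
    print(value)  # [1. 2.]
```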

To learn how to use Python lambda functions, you can refer to:

Understand Python Lambda Function for Beginners – Python Tutorial

Can pred be a Python bool?

Look at the code below:

result = tf.cond(True, lambda: t1, lambda: t2)

Here pred is True, which is a Python bool. Will t1 be returned?

Run this code and you will get an error:

TypeError: pred must not be a Python bool

This means pred cannot be a Python bool. We should convert it to a tensor first, for example with tf.convert_to_tensor().
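A minimal sketch of both cases, assuming TensorFlow 1.x graph mode through tf.compat.v1 (the Python bool check applies in this mode). The constant tensors t1 and t2 here are small hypothetical stand-ins for the variables above.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph mode

t1 = tf.constant([1.0, 2.0])
t2 = tf.constant([3.0, 4.0])

# Passing a plain Python bool is rejected with a TypeError.
raised = False
try:
    tf.cond(True, lambda: t1, lambda: t2)
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Converting the bool to a rank-0 tensor first works.
result = tf.cond(tf.convert_to_tensor(True), lambda: t1, lambda: t2)

with tf.Session() as sess:
    value = sess.run(result)
    print(value)  # [1. 2.]
```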

The rank of pred

pred must be a tensor of rank 0, that is, a scalar.

Look at this code:

t3 = tf.Variable(np.array([[200, 4, 5], [20, 5, 70]]), dtype = tf.float32)
t4 = tf.Variable(np.array([[1, 4, 5], [1, 5, 70]]), dtype = tf.float32)
                 
result = tf.cond(t4>t3, lambda: t1, lambda: t2)

Run this code and you will get this error:

ValueError: Shape must be rank 0 but is rank 2 for 'cond/Switch' (op: 'Switch') with input shapes: [2,3], [2,3].
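t4 > t3 compares element by element, so it produces a rank-2 bool tensor. If you want a single decision, one possible fix (an assumption about what the comparison is meant to express) is to collapse it to a scalar with tf.reduce_any() or tf.reduce_all(). If you actually want to choose element by element, tf.where() selects between two tensors of the same shape without tf.cond(). A minimal sketch in TensorFlow 1.x graph mode through tf.compat.v1, returning t4 or t3 instead of t1 or t2:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph mode

t3 = tf.constant([[200.0, 4.0, 5.0], [20.0, 5.0, 70.0]])
t4 = tf.constant([[1.0, 4.0, 5.0], [1.0, 5.0, 70.0]])

# t4 > t3 has rank 2; tf.reduce_any() reduces it to a rank-0 bool tensor.
pred = tf.reduce_any(t4 > t3)
print(pred.shape.ndims)  # 0

# Now pred is valid for tf.cond(). No element of t4 is greater, so t3 is returned.
result = tf.cond(pred, lambda: t4, lambda: t3)

# Element-wise alternative: take t4 where t4 > t3, otherwise t3.
elementwise = tf.where(t4 > t3, t4, t3)

with tf.Session() as sess:
    pred_value, result_value, elementwise_value = sess.run([pred, result, elementwise])
    print(pred_value)
    print(result_value)
```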

How to use tf.cond()?

We will use an example to show how to use it.

Look at this example:

import tensorflow as tf
import numpy as np

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]],[[2, 3, 5], [5, 5, 7]]]), dtype = tf.float32, name = 'labels')

def addT(a, b):
    return a+2*b
def subT(a, b):
    return a-3*b

t2 = tf.Variable(np.array([1,2]), dtype = tf.float32, name = 'predictions')

result = tf.cond(tf.less(3,4), addT(t1,t1), subT(t2, t2))
init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(result))

Will addT(t1, t1) work here? Run this code and you will get this error:

TypeError: true_fn must be callable.

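tf.cond() expects true_fn and false_fn to be functions it can call. Here addT(t1, t1) is executed immediately and passes a tensor instead of a function, so tf.cond() rejects it. You should modify the code to wrap each call in a lambda: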

result = tf.cond(tf.less(3,4), lambda: addT(t1,t1), lambda: subT(t2, t2))
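tf.cond() only needs callables, so a lambda is one option among several. A minimal sketch, assuming TensorFlow 1.x graph mode through tf.compat.v1, that uses functools.partial from the Python standard library instead; the constants a and b are hypothetical stand-ins for t1 and t2.

```python
import functools

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph mode

def addT(a, b):
    return a + 2 * b

def subT(a, b):
    return a - 3 * b

a = tf.constant([1.0, 2.0])
b = tf.constant([10.0, 20.0])

# functools.partial binds the arguments and returns a callable,
# which tf.cond() calls when it builds each branch.
true_fn = functools.partial(addT, a, a)
false_fn = functools.partial(subT, b, b)
print(callable(true_fn), callable(false_fn))  # True True

result = tf.cond(tf.less(3, 4), true_fn, false_fn)

with tf.Session() as sess:
    value = sess.run(result)
    print(value)  # [3. 6.]
```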

Because tf.less(3, 4) is true, tf.cond() returns addT(t1, t1), which is t1 + 2*t1 = 3*t1. Run this code and you will get:

[[[600.  12.  15.]
  [ 60.  15. 210.]]

 [[  6.   9.  15.]
  [ 15.  15.  21.]]]
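To check the other branch, flip the predicate. This sketch reuses t1, t2, addT and subT from the example above, again assuming TensorFlow 1.x graph mode through tf.compat.v1. tf.less(4, 3) is false, so subT(t2, t2) = t2 - 3*t2 = -2*t2 is returned.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x graph mode

t1 = tf.Variable(np.array([[[200, 4, 5], [20, 5, 70]], [[2, 3, 5], [5, 5, 7]]]),
                 dtype=tf.float32, name='labels')
t2 = tf.Variable(np.array([1, 2]), dtype=tf.float32, name='predictions')

def addT(a, b):
    return a + 2 * b

def subT(a, b):
    return a - 3 * b

# 4 < 3 is false, so the value of the false branch is returned.
result = tf.cond(tf.less(4, 3), lambda: addT(t1, t1), lambda: subT(t2, t2))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(result)
    print(value)  # [-2. -4.]
```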