Tutorial Example

Use If Condition Statement in TensorFlow – TensorFlow Tutorial

In Python, we can use an if statement to control the flow of an application. We can also use if statements in TensorFlow, but they need some care. In this tutorial, we will discuss this topic.

For example, in a BiLSTM model we have used a Python boolean variable to control whether or not to reverse a sequence:

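# revers is a plain Python bool: it is checked once, when the graph is built
if revers:
    if sequence_length is not None:
        # reverse only the valid (unpadded) time steps of each sequence
        inputs = tf.reverse_sequence(inputs, seq_lengths=sequence_length, seq_axis=1, batch_axis=0)
    else:
        # reverse the whole time axis
        inputs = tf.reverse(inputs, axis=[1])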

This snippet comes from the tutorial: Build a Custom BiLSTM Model Using TensorFlow: A Step Guide – TensorFlow Tutorial
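To see why the snippet treats the two cases differently, here is a pure-Python sketch of what the two ops do to a padded batch with seq_axis=1 and batch_axis=0. It uses lists instead of tensors, and the helper names reverse_sequence and reverse are our own stand-ins, not TensorFlow calls.

```python
def reverse_sequence(batch, seq_lengths):
    """Pure-Python model of tf.reverse_sequence(..., seq_axis=1, batch_axis=0).

    For each row, only the first seq_lengths[i] steps are reversed;
    the remaining (padding) steps stay where they are.
    """
    out = []
    for row, n in zip(batch, seq_lengths):
        out.append(row[:n][::-1] + row[n:])
    return out


def reverse(batch):
    """Pure-Python model of tf.reverse(..., axis=[1]): reverse every full row."""
    return [row[::-1] for row in batch]


batch = [[1, 2, 3, 0],
         [4, 5, 0, 0]]  # 0 marks padding
print(reverse_sequence(batch, [3, 2]))  # [[3, 2, 1, 0], [5, 4, 0, 0]]
print(reverse(batch))                   # [[0, 3, 2, 1], [0, 0, 5, 4]]
```

With padded batches, reversing whole rows moves the padding to the front, which is why tf.reverse_sequence is used when sequence_length is known.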

However, it is not a good idea to use a Python boolean variable to control the flow of a TensorFlow application.

Look at the example code below:

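import tensorflow as tf
import numpy as np

# TensorFlow 1.x API (tf.Session). In TensorFlow 2.x, call
# tf.compat.v1.disable_eager_execution() and use the tf.compat.v1 equivalents.


class HBiLSTMG(object):

    def __init__(self, w):
        self.w = w

        self.c = 2*self.w

        self.training = True
        # Python if: evaluated only once, while the graph is being built
        if self.training:
            self.c = self.c+2


v1 = tf.Variable(np.array([2, 6, 3, 3, 4, 2, 4, 3]), dtype=tf.float32, name='w1')
v2 = tf.reshape(v1, [4, -1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    np.set_printoptions(precision=3, suppress=True)

    hm = HBiLSTMG(v2)

    print(sess.run(hm.c))
    # this only changes a Python attribute; the graph behind hm.c is unchanged
    hm.training = False
    print(hm.training)
    print(sess.run(hm.c))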

What is the value of hm.c when hm.training is True, and when it is False?

Run this code and you will find:

When hm.training = True, the value of hm.c is:

[[ 6. 14.]
 [ 8.  8.]
 [10.  6.]
 [10.  8.]]

When hm.training = False, the value of hm.c is also:

[[ 6. 14.]
 [ 8.  8.]
 [10.  6.]
 [10.  8.]]

Why doesn't the value of hm.c change?

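Because when we run hm = HBiLSTMG(v2), the Python if statement is executed only once, and TensorFlow builds the computation graph with training = True, so the +2 operation becomes part of the graph. Setting hm.training = False afterwards only changes a Python attribute; the graph is not rebuilt, so hm.c does not change.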
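The same effect can be reproduced without TensorFlow. In this toy sketch, the Model class and its compute attribute are hypothetical stand-ins for HBiLSTMG and its graph, not TensorFlow API: __init__ picks a branch with a Python if and stores only that branch, so flipping training afterwards has no effect.

```python
class Model:
    """Toy stand-in for HBiLSTMG: __init__ 'builds' a computation once."""

    def __init__(self, w):
        self.training = True
        # Python if: evaluated once, right here, while the computation is built
        if self.training:
            self.compute = lambda: [2 * x + 2 for x in w]
        else:
            self.compute = lambda: [2 * x for x in w]


m = Model([2, 6])
print(m.compute())   # [6, 14]
m.training = False   # only changes an attribute...
print(m.compute())   # ...the stored computation still gives [6, 14]
```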

How can we implement an if statement that controls the computation graph in TensorFlow?

A good way is to feed the Python boolean value into the graph through a tf.placeholder, then use tf.cond() to choose which branch to compute each time the graph runs.
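Conceptually, tf.cond() keeps both branches and decides which one to evaluate every time the graph runs, based on the value that is fed in. Below is a minimal pure-Python sketch of that idea; cond, node and the string key "is_training" are hypothetical stand-ins for tf.cond() and tf.placeholder, not the real API.

```python
def cond(pred_name, true_fn, false_fn):
    """Toy model of tf.cond: keep BOTH branches, choose one at run time.

    pred_name is the key of a boolean supplied in feed_dict,
    playing the role of a tf.placeholder(tf.bool).
    """
    def run(feed_dict):
        return true_fn() if feed_dict[pred_name] else false_fn()
    return run


w = [2, 6, 3, 3]
c = [2 * x for x in w]
node = cond("is_training",
            lambda: [x + 2 for x in c],   # training branch
            lambda: c)                    # untraining branch

print(node({"is_training": True}))    # [6, 14, 8, 8]
print(node({"is_training": False}))   # [4, 12, 6, 6]
```

Because the choice is made inside run(), the same node gives different results for different feeds, which is exactly what the Python if in __init__ could not do.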

Modify the example above as follows:

import tensorflow as tf
import numpy as np


class HBiLSTMG(object):

    def __init__(self, w):
        self.w = w

        self.c = 2*self.w

        # the condition is now a graph input, fed each time the graph runs
        self.training = tf.placeholder(tf.bool, [], name="is_training")

        def training():
            return self.c+2
        def untraining():
            return self.c
        # tf.cond() adds both branches to the graph and picks one at run time
        self.c = tf.cond(self.training, training, untraining)


v1 = tf.Variable(np.array([2, 6, 3, 3, 4, 2, 4, 3]), dtype=tf.float32, name='w1')
v2 = tf.reshape(v1, [4, -1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    np.set_printoptions(precision=3, suppress=True)

    hm = HBiLSTMG(v2)
    feed_dict = {hm.training: True}
    c = sess.run([hm.c], feed_dict)
    print(c)
    feed_dict = {hm.training: False}
    c = sess.run([hm.c], feed_dict)
    print(c)

We feed the Python boolean value into the tf.placeholder with feed_dict:

feed_dict = {hm.training: True}

Run this code and we will get the result:

[array([[ 6., 14.],
       [ 8.,  8.],
       [10.,  6.],
       [10.,  8.]], dtype=float32)]
[array([[ 4., 12.],
       [ 6.,  6.],
       [ 8.,  4.],
       [ 8.,  6.]], dtype=float32)]
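We can check both results by hand. This plain-Python sketch uses lists instead of tensors to reproduce the reshape and the two branches:

```python
values = [2, 6, 3, 3, 4, 2, 4, 3]

# tf.reshape(v1, [4, -1]) on 8 values gives 4 rows of 2 (row-major order)
w = [values[i:i + 2] for i in range(0, len(values), 2)]

c_false = [[2 * x for x in row] for row in w]        # 2 * w
c_true = [[x + 2 for x in row] for row in c_false]   # 2 * w + 2

print(w)        # [[2, 6], [3, 3], [4, 2], [4, 3]]
print(c_true)   # [[6, 14], [8, 8], [10, 6], [10, 8]]
print(c_false)  # [[4, 12], [6, 6], [8, 4], [8, 6]]
```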

We can see that the value of hm.c now changes with hm.training.

To learn more about how to use the tf.cond() function, you can refer to:

Understand tf.cond(): Run TensorFlow Function By True Condition