Understand tf.Variable(): A Beginner Guide – TensorFlow Tutorial

November 8, 2019

tf.Variable() is commonly used to create a variable (tensor) in a TensorFlow application. This tutorial shows TensorFlow beginners how to use it correctly.

Syntax of tf.Variable()

Note that tf.Variable is a Python class, not a Python function, so calling tf.Variable() invokes its constructor:

__init__(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
    caching_device=None,
    name=None,
    variable_def=None,
    dtype=None,
    expected_shape=None,
    import_scope=None,
    constraint=None
)

The constructor creates a new variable with the value initial_value.

Parameters explained

initial_value: the initial value of the new variable.

trainable: whether the variable can be trained by the model. If you set it to False, the value of the variable is not modified when the loss is minimized.

name: the variable name. You should set a name so that you can look the variable up by name later. This is very useful when you load an existing model.

dtype: the data type of the variable, for example tf.float32 or tf.int32. If it is not set, the data type is inferred from initial_value.
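As a rough illustration of the parameters above, the sketch below creates two variables and inspects their attributes. It assumes TensorFlow 1.15 or 2.x, where the tensorflow.compat.v1 module provides the graph-mode API used in this tutorial. The variable names global_step and weight are made up for the example.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.15 or 2.x with compat.v1

tf.disable_eager_execution()  # keep the graph/session style used in this tutorial

graph = tf.Graph()
with graph.as_default():
    # Non-trainable integer counter (hypothetical name 'global_step').
    step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int32)
    # Trainable float matrix (hypothetical name 'weight'); dtype is set explicitly.
    weight = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name='weight', dtype=tf.float32)

    # Look a variable up by its name among the graph's global variables.
    by_name = {v.name: v for v in tf.global_variables()}
    found = by_name['weight:0']

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        weight_value = sess.run(found)

print(step.name, step.dtype.base_dtype, step.trainable)
print(weight.name, weight.dtype.base_dtype, weight.trainable)
print(weight_value)
```

Setting a name explicitly is what makes the lookup by 'weight:0' possible. The ':0' suffix is appended by TensorFlow.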

Notice:

1. tf.Variable() always creates a new variable, regardless of the reuse setting in tf.variable_scope().

2. A variable created with tf.Variable() is added to the tf.GraphKeys.GLOBAL_VARIABLES collection. If trainable = True, it is also added to tf.GraphKeys.TRAINABLE_VARIABLES. If trainable = False, it is only added to tf.GraphKeys.GLOBAL_VARIABLES.
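The first note can be checked with a small sketch that contrasts tf.Variable() with tf.get_variable() inside a reusing scope. It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1. The scope name layer and the variable names w and v are hypothetical.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.15 or 2.x with compat.v1

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope('layer', reuse=tf.AUTO_REUSE):
        # tf.Variable() ignores reuse: two calls give two distinct variables.
        a = tf.Variable(1.0, name='w')
        b = tf.Variable(1.0, name='w')
        # tf.get_variable() honors reuse: the second call returns the first variable.
        c = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())
        d = tf.get_variable('v', shape=[], initializer=tf.zeros_initializer())

    global_names = [v.name for v in tf.global_variables()]

print(a.name, b.name)  # two different names, made unique by TensorFlow
print(c is d)          # whether both get_variable calls returned one object
print(global_names)    # every variable registered in GLOBAL_VARIABLES
```

If you need to share a variable across calls, tf.get_variable() is the function that respects reuse. tf.Variable() is the right choice when a fresh variable is wanted every time.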

Here is an example:

import tensorflow as tf
import numpy as np

# Create a 2 x 2 variable initialized from a standard normal distribution
w1 = tf.Variable(tf.random_normal(shape=[2, 2], mean=0, stddev=1), name='w')
with tf.Session() as sess:
    # Variables must be initialized before they are read
    sess.run(tf.global_variables_initializer())
    print("w1 = ")
    print(w1.name)
    print(w1.eval())
    # List the variables in each collection
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

In this example, we use tf.Variable() to create a new TensorFlow variable. trainable is left at its default value, True.

Run this script and we will find (the random matrix printed by w1.eval() is omitted here):

w1 = 
w:0
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]

The name of the new variable is 'w:0'. It is added to both the tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES collections.

If we set trainable = False:

w1 = tf.Variable(tf.random_normal(shape=[2, 2], mean=0, stddev=1), name='w', trainable=False)

Then print this variable and the collections:

with tf.Session() as sess:  
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print("w1 = ")
    print(w1.name)
 
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) )
    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) )
    print(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))

Run this code and we will find:

w1 = 
w:0
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]
[]
[]

This variable is only added to tf.GraphKeys.GLOBAL_VARIABLES. It is in neither tf.GraphKeys.TRAINABLE_VARIABLES nor tf.GraphKeys.LOCAL_VARIABLES.

Why we use tf.Variable() to create a new variable

Variables created by tf.Variable() are often used to store the weights or biases of a deep learning model. Trainable variables are updated when an optimizer minimizes the model's loss function. Variables can also be saved to a file.
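Both points can be sketched in a few lines: an optimizer updates a trainable variable while leaving a non-trainable one alone, and a saver writes the values to disk. This is a minimal sketch, assuming TensorFlow 1.15 or 2.x through tensorflow.compat.v1. The toy loss, the variable names w and frozen, and the temporary checkpoint path are all made up for illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.15 or 2.x with compat.v1

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(5.0, name='w')                             # trainable by default
    frozen = tf.Variable(5.0, name='frozen', trainable=False)  # excluded from training

    # Toy loss; minimize() only updates variables in TRAINABLE_VARIABLES.
    loss = tf.square(w) + tf.square(frozen)
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
    saver = tf.train.Saver()  # saves all global variables by default

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
            sess.run(train_op)
        w_value, frozen_value = sess.run([w, frozen])
        # Hypothetical checkpoint location inside a fresh temporary directory.
        save_path = saver.save(sess, os.path.join(tempfile.mkdtemp(), 'model.ckpt'))

print(w_value)       # moved toward 0 by gradient descent
print(frozen_value)  # unchanged, because trainable=False
print(save_path)
```

The non-trainable variable still ends up in the checkpoint, because saving works from the global variables rather than the trainable ones.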

Examples

Here are some examples of creating a new variable.

w = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='w')

Create a 2 * 3 matrix tensor.

b = tf.Variable(np.arange(9, dtype=np.float32).reshape(3, 3), dtype=tf.float32)

Create a 3 * 3 matrix tensor from a NumPy array holding the values 0 to 8.
