tf.Variable() is often used to create a TensorFlow variable (tensor) in a TensorFlow application. This tutorial, written for TensorFlow beginners, discusses how to use it correctly.
Syntax of tf.Variable()
Note that tf.Variable is a Python class, not a Python function. In the TensorFlow 1.x API used throughout this tutorial, its constructor is:
__init__(
    initial_value=None,
    trainable=True,
    collections=None,
    validate_shape=True,
    caching_device=None,
    name=None,
    variable_def=None,
    dtype=None,
    expected_shape=None,
    import_scope=None,
    constraint=None
)
It creates a new variable whose initial value is initial_value.
Parameters explained
initial_value: the initial value of the new variable.
trainable: whether the variable can be trained by the model. If you set it to False, the value of the variable is not modified when minimizing the loss.
name: the variable name. You should set a name so that you can retrieve the variable's value by name later. This is very useful when you load an existing model.
dtype: the data type of the value, for example tf.float32 or tf.int32. If it is None, the data type is inferred from initial_value.
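As an illustration of these parameters, here is a minimal sketch that creates one variable and reads it back by name. It assumes the tensorflow.compat.v1 module is available, so that the TensorFlow 1.x graph API used in this tutorial also runs under TensorFlow 2.x. The variable name 'counts' and its values are made up for the example.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # assumption: run in TF 1.x-style graph mode

graph = tf.Graph()
with graph.as_default():
    # initial_value is a list of Python ints; dtype converts it to float32.
    counts = tf.Variable([1, 2, 3], dtype=tf.float32, name="counts", trainable=False)
    init = tf.global_variables_initializer()
    # Index the graph's variables by name so one can be looked up later.
    by_name = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)}

with tf.Session(graph=graph) as sess:
    sess.run(init)
    value = sess.run(by_name["counts:0"])
    print(counts.name)
    print(counts.dtype.base_dtype)
    print(value)
```

The lookup through the GLOBAL_VARIABLES collection is only one possible way to find a variable by name. It is used here because it needs nothing beyond the collections described in this tutorial.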
Notice:
1. tf.Variable() always creates a new variable, regardless of the reuse setting of an enclosing tf.variable_scope().
2. A variable created with tf.Variable() is added to tf.GraphKeys.GLOBAL_VARIABLES. If you set trainable = True, it is also added to tf.GraphKeys.TRAINABLE_VARIABLES. If you set trainable = False, it is added only to tf.GraphKeys.GLOBAL_VARIABLES.
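The first point can be checked with a short sketch. It assumes the tensorflow.compat.v1 module is available, and the scope name 'layer' and the variable names are made up for the example. tf.get_variable() is not covered in this tutorial; it appears here only as a contrast, because it does honor reuse.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # assumption: run in TF 1.x-style graph mode

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("layer"):
        a = tf.Variable(1.0, name="w")
        shared_1 = tf.get_variable(
            "shared", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer())

    with tf.variable_scope("layer", reuse=True):
        # tf.Variable() ignores reuse and builds a second, independent variable.
        b = tf.Variable(1.0, name="w")
        # tf.get_variable() honors reuse and returns the existing variable.
        shared_2 = tf.get_variable(
            "shared", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer())

    global_names = [v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]

print(a is b)                # False: two separate variables
print(shared_1 is shared_2)  # True: the same variable is reused
print(global_names)
```

Because each tf.Variable() call adds another entry to GLOBAL_VARIABLES, the list printed at the end contains three variables: the two created by tf.Variable() and the single shared one.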
Here is an example:
import tensorflow as tf
import numpy as np

w1 = tf.Variable(tf.random_normal(shape=[2,2], mean=0, stddev=1), name='w')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("w1 = ")
    print(w1.name)
    print(w1.eval())
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
In this example, we use tf.Variable() to create a new TensorFlow variable and leave trainable at its default value, True.
Running this script prints the following (the random values printed by w1.eval() are not shown):
w1 = 
w:0
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]
The name of the new variable is 'w:0', and it is added to both the tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES collections of the graph.
If we set trainable = False:
w1 = tf.Variable(tf.random_normal(shape=[2,2], mean=0, stddev=1), name='w', trainable = False)
Then print this variable's name and the collections:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print("w1 = ")
    print(w1.name)
    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    print(tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
Run this code and we will find:
w1 = 
w:0
[<tf.Variable 'w:0' shape=(2, 2) dtype=float32_ref>]
[]
[]
This variable is added only to tf.GraphKeys.GLOBAL_VARIABLES, not to tf.GraphKeys.TRAINABLE_VARIABLES or tf.GraphKeys.LOCAL_VARIABLES.
Why use tf.Variable() to create a new variable
Variables created by tf.Variable() are often used to store the weights or biases of a deep learning model. Their values can be modified while minimizing the model's loss function, and they can also be saved to a file.
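To make this concrete, the following sketch trains one variable, leaves a second one frozen with trainable=False, and then saves and restores both. It assumes the tensorflow.compat.v1 module is available. tf.train.GradientDescentOptimizer and tf.train.Saver are TensorFlow 1.x classes that are not covered in this tutorial. The variable names, the toy loss, the learning rate and the checkpoint file name are all made up for the example.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # assumption: run in TF 1.x-style graph mode

graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(0.0, name="w")                             # trainable by default
    frozen = tf.Variable(0.0, name="frozen", trainable=False)  # never updated
    # Toy loss: both terms are smallest when the variable equals 3.
    loss = tf.square(w - 3.0) + tf.square(frozen - 3.0)
    # minimize() only updates variables in TRAINABLE_VARIABLES.
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(100):
        sess.run(train_op)
    w_value, frozen_value = sess.run([w, frozen])

    # Save the variables to a checkpoint file in a temporary directory.
    ckpt_path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

    sess.run(init)                  # reset both variables to 0.0
    saver.restore(sess, ckpt_path)  # load the trained values back
    restored_w = sess.run(w)

print(w_value, frozen_value, restored_w)
```

After training, w has moved close to 3 while frozen is still 0, and restoring the checkpoint brings w back to its trained value.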
Examples
Here are some examples of creating a new variable.
w = tf.Variable(tf.random_normal(shape=[2,3], mean=0, stddev=1), name='w')
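This creates a 2 x 3 matrix variable filled with random values.
b = tf.Variable(np.arange(9).reshape(3, 3), dtype=tf.float32)
This creates a 3 x 3 matrix variable from a NumPy array: np.arange(9) generates the values 0 to 8 and reshape(3, 3) arranges them into a matrix.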
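When a variable is initialized from NumPy, it is worth checking the array's shape first, because the variable takes its shape from the array. The sketch below assumes the tensorflow.compat.v1 module is available. It compares np.arange(3, 3), which is empty because start and stop are equal, with a reshaped range.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # assumption: run in TF 1.x-style graph mode

empty = np.arange(3, 3)              # start == stop, so there are no elements
matrix = np.arange(9).reshape(3, 3)  # values 0..8 arranged as a 3 x 3 array

graph = tf.Graph()
with graph.as_default():
    # The variable takes its shape from the NumPy array.
    b = tf.Variable(matrix, dtype=tf.float32, name="b")

print(empty.shape)        # (0,)
print(matrix.shape)       # (3, 3)
print(b.shape.as_list())  # [3, 3]
```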