List All Trainable and Untrainable Variables in TensorFlow – TensorFlow Tutorial

By | June 12, 2019

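Trainable variables are very important in TensorFlow. They are the variables an optimizer updates during training by default, and they are usually the ones we collect when we add a regularization term (for example, an L2 penalty) to our model's loss.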
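As an illustration of the regularization use case, here is a minimal sketch that builds an L2 penalty over tf.trainable_variables() only. It assumes TensorFlow 1.14+ or 2.x and uses tensorflow.compat.v1 with tf.disable_v2_behavior() so the TF1-style graph code of this tutorial also runs on TensorFlow 2.x. The regularization strength reg_lambda is a hypothetical value chosen for the example.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # build a TF1-style graph, as in this tutorial
tf.reset_default_graph()

# Two variables filled with ones; only the first one is trainable.
w_train = tf.Variable(tf.ones([2, 3]), name='w_train')
w_untrain = tf.Variable(tf.ones([2, 3]), trainable=False, name='w_untrain')

reg_lambda = 0.01  # hypothetical regularization strength

# tf.nn.l2_loss(v) computes sum(v ** 2) / 2; sum it over trainable variables only.
l2_penalty = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
reg_loss = reg_lambda * l2_penalty

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    penalty_value, reg_value = sess.run([l2_penalty, reg_loss])

print(penalty_value)  # 3.0: six ones in w_train, squared and halved
print(reg_value)
```

Because w_untrain is not in the trainable collection, it contributes nothing to the penalty. In a real model, reg_loss would be added to the data loss before calling the optimizer.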

How do you list all trainable variables in TensorFlow?

How do you list all variables, both trainable and untrainable?

This tutorial answers both questions with a short example.

Step 1. Define four TensorFlow objects: an untrainable variable, a trainable variable, a constant and a placeholder.

w_untrain = tf.Variable(tf.random_uniform([2,3], -1, 1), trainable=False, name='w_untrain')

w_train = tf.Variable(tf.random_uniform([2,3], -1, 1), name='w_train')

c_train = tf.constant(np.random.uniform(-1,1,[2,3]), dtype=tf.float32, name='c_train')

input_x = tf.placeholder(tf.int32, [None, 40, 50], name="input_x")
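The trainable flag decides which graph collections a variable joins: every tf.Variable is added to tf.GraphKeys.GLOBAL_VARIABLES, and only variables created with trainable=True (the default) are also added to tf.GraphKeys.TRAINABLE_VARIABLES. The constant and the placeholder are plain tensors and join neither. The sketch below reads both collections directly for the four objects above; it assumes tensorflow.compat.v1 (TensorFlow 1.14+ or 2.x) so it runs on either major version.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode
tf.reset_default_graph()

# The four objects from Step 1.
w_untrain = tf.Variable(tf.random_uniform([2, 3], -1, 1), trainable=False, name='w_untrain')
w_train = tf.Variable(tf.random_uniform([2, 3], -1, 1), name='w_train')
c_train = tf.constant(np.random.uniform(-1, 1, [2, 3]), dtype=tf.float32, name='c_train')
input_x = tf.placeholder(tf.int32, [None, 40, 50], name='input_x')

# Read the two collections directly; no session is needed just to list them.
global_names = [v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
trainable_names = [v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]

print(global_names)     # ['w_untrain:0', 'w_train:0']
print(trainable_names)  # ['w_train:0']
```

tf.global_variables() and tf.trainable_variables() return the contents of these same two collections, so they give the same lists.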

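Step 2. Use tf.trainable_variables() to get all trainable variables, and tf.global_variables() to get all variables, both trainable and untrainable. Constants and placeholders are not variables, so neither function returns c_train or input_x.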
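tf.global_variables() returns trainable and untrainable variables together. If you need only the untrainable ones, one way is to remove every trainable variable from the global list. This is a minimal sketch assuming tensorflow.compat.v1 (TensorFlow 1.14+ or 2.x); the filtering logic is our own illustration, not a dedicated TensorFlow API.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graph mode
tf.reset_default_graph()

w_untrain = tf.Variable(tf.zeros([2, 3]), trainable=False, name='w_untrain')
w_train = tf.Variable(tf.zeros([2, 3]), name='w_train')

# Compare by name: TensorFlow makes variable names unique within a graph.
trainable_names = {v.name for v in tf.trainable_variables()}
untrainable = [v for v in tf.global_variables() if v.name not in trainable_names]

for v in untrainable:
    print(v.name)  # w_untrain:0
```

Comparing names rather than the variable objects themselves keeps the membership test a plain string lookup and avoids relying on how a variable's == operator behaves.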

Here is the full code.

# TensorFlow 1.x API; on TensorFlow 2.x, use "import tensorflow.compat.v1 as tf" and call tf.disable_v2_behavior() first
import numpy as np
import tensorflow as tf

# untrainable variable: excluded from tf.trainable_variables()
w_untrain = tf.Variable(tf.random_uniform([2, 3], -1, 1), trainable=False, name='w_untrain')

# trainable variable (trainable=True is the default)
w_train = tf.Variable(tf.random_uniform([2, 3], -1, 1), name='w_train')

# a constant is a plain tensor, not a variable
c_train = tf.constant(np.random.uniform(-1, 1, [2, 3]), dtype=tf.float32, name='c_train')

# a placeholder is also a plain tensor, not a variable
input_x = tf.placeholder(tf.int32, [None, 40, 50], name='input_x')

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    # get all trainable variables
    print('all trainable variables')
    for name in [v.name for v in tf.trainable_variables()]:
        print(name)
    # get all trainable and untrainable variables
    print('all trainable and untrainable variables')
    for name in [v.name for v in tf.global_variables()]:
        print(name)

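The output is:

all trainable variables
w_train:0
all trainable and untrainable variables
w_untrain:0
w_train:0

Only w_train is trainable, and c_train and input_x do not appear at all because they are not variables.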
