In this tutorial, we will use an example to help you understand tf.clip_by_global_norm() correctly.
tf.clip_by_global_norm()
This function is defined as:
tf.clip_by_global_norm( t_list, clip_norm, use_norm=None, name=None )
It clips the values of a list of tensors t_list by the ratio of clip_norm to their global norm, which is the square root of the sum of their squared L2 norms.
Each clipped tensor t_list[i] is computed as:
t_list[i] * clip_norm / max(global_norm, clip_norm)
where:
global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))
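The formula can be sketched in plain Python as below. This is a minimal illustration, not TensorFlow's implementation. It assumes each tensor is a flat list of floats and that clip_norm > 0. The function name clip_by_global_norm_py is hypothetical.

```python
import math


def clip_by_global_norm_py(t_list, clip_norm):
    """Hypothetical pure-Python sketch of the clipping formula.

    Assumptions: each tensor is a flat list of floats and clip_norm > 0.
    Returns (list_clipped, global_norm), mirroring the two return values.
    """
    # global_norm = sqrt(sum of squared L2 norms) = sqrt(sum of all squared entries)
    global_norm = math.sqrt(sum(x * x for t in t_list for x in t))
    # Every tensor is multiplied by the same factor clip_norm / max(global_norm, clip_norm)
    scale = clip_norm / max(global_norm, clip_norm)
    list_clipped = [[x * scale for x in t] for t in t_list]
    return list_clipped, global_norm


# Two "tensors" whose combined L2 norm is sqrt(3^2 + 4^2) = 5
clipped, norm = clip_by_global_norm_py([[3.0], [4.0]], 2.5)
print(clipped, norm)  # [[1.5], [2.0]] 5.0
```

Both tensors are scaled by the same factor (0.5 here), so their relative sizes are preserved. Clipping by global norm differs from clipping each tensor separately in exactly this way.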
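From the formula above, if clip_norm >= global_norm, the scale factor is 1 and the entries in t_list remain unchanged. Otherwise, every tensor is scaled down by the same factor clip_norm / global_norm.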
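A short derivation makes this explicit. Write s for the common scale factor. It reduces to a minimum, and the global norm of the clipped list follows directly from it:

```latex
s = \frac{\text{clip\_norm}}{\max(\text{global\_norm}, \text{clip\_norm})}
  = \min\!\left(1, \frac{\text{clip\_norm}}{\text{global\_norm}}\right)

\sqrt{\sum_i \lVert s \, t_i \rVert_2^2}
  = s \sqrt{\sum_i \lVert t_i \rVert_2^2}
  = s \cdot \text{global\_norm}
  = \min(\text{global\_norm}, \text{clip\_norm})
```

After clipping, the global norm of the list is therefore never larger than clip_norm.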
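It returns two values:
list_clipped, the list of clipped tensors, and global_norm, a scalar tensor holding the global norm of the tensors in t_list before clipping.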
How to use tf.clip_by_global_norm() in TensorFlow?
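Look at the example below:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable)
import numpy as np

# A 5x5 weight matrix with random initial values
w = tf.get_variable('weight', shape=[5, 5], dtype=tf.float32)
# Clip the list [w] so that its global norm is at most 3.0
w_n = tf.clip_by_global_norm([w], 3.0)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    # w1: original values; w2: the tuple (list_clipped, global_norm)
    w1, w2 = sess.run([w, w_n])
    print(w1)
    print(w2)
```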
Running this code prints:

```
[[-0.1529 -0.3025 -0.2745 -0.7379  0.6346]
 [-0.7407  0.5338 -0.2375 -0.7452  0.4232]
 [ 0.7719  0.5403  0.4836 -0.7039  0.5534]
 [-0.7702  0.0088  0.5766  0.2401 -0.5267]
 [ 0.7486  0.7236 -0.1808 -0.7242  0.7003]]
([array([[-0.1529, -0.3025, -0.2745, -0.7379,  0.6346],
       [-0.7407,  0.5338, -0.2375, -0.7452,  0.4232],
       [ 0.7719,  0.5403,  0.4836, -0.7039,  0.5534],
       [-0.7702,  0.0088,  0.5766,  0.2401, -0.5267],
       [ 0.7486,  0.7236, -0.1808, -0.7242,  0.7003]], dtype=float32)], 2.8421102)
```
Here global_norm = 2.8421102 < 3.0, so the scale factor is 1 and the weight w is returned unchanged.
What happens if global_norm > 3.0?
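```python
# Continues the script above: tf, np and the variable w are already defined.
# Add 1.0 to every element so that the global norm exceeds 3.0
w = w + 1.0
w_n = tf.clip_by_global_norm([w], 3.0)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    w1, w2 = sess.run([w, w_n])
    print(w1)
    print(w2)
```

Here we add 1.0 to every element of w. The variable is re-initialized with new random values in the new session, so its values differ from the first run. Running this code prints: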
```
[[0.8825 1.3536 0.7567 1.3894 1.0013]
 [1.7247 0.8422 0.9118 0.918  0.6208]
 [1.5703 0.7254 1.5474 0.9198 0.8047]
 [1.169  1.5639 1.6038 1.0512 1.5293]
 [0.7301 0.4894 0.5306 1.0528 0.3921]]
([array([[0.4764, 0.7307, 0.4085, 0.75  , 0.5405],
       [0.931 , 0.4546, 0.4922, 0.4956, 0.3351],
       [0.8477, 0.3916, 0.8353, 0.4965, 0.4344],
       [0.6311, 0.8442, 0.8658, 0.5675, 0.8255],
       [0.3941, 0.2642, 0.2864, 0.5684, 0.2116]], dtype=float32)], 5.557393)
```
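Here global_norm = 5.557393 > 3.0, so w is clipped. Every element is multiplied by 3.0 / 5.557393 ≈ 0.5398, which reduces the global norm of the result to 3.0. The returned global_norm (5.557393) is the norm before clipping.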
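As a quick sanity check of the second run, the scale factor can be recomputed by hand from the printed numbers. The snippet only does plain arithmetic on values copied from the output above and does not call TensorFlow.

```python
import math

# Values copied from the second run's output
global_norm = 5.557393
clip_norm = 3.0

# global_norm > clip_norm, so the scale factor is clip_norm / global_norm
scale = clip_norm / max(global_norm, clip_norm)
print(round(scale, 4))  # 0.5398

# First element of w before clipping (0.8825) and after clipping
print(round(0.8825 * scale, 4))  # 0.4764

# The clipped tensor's global norm equals clip_norm
print(math.isclose(scale * global_norm, clip_norm))  # True
```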