In this tutorial, we will introduce how to apply gradient clipping in TensorFlow. It is very useful for keeping your model training stable.
Step 1: create an optimizer with a learning rate
For example:
def optim(lr):
    """
    return optimizer determined by configuration
    :return: tf optimizer
    """
    if config.optim == "sgd":
        return tf.train.GradientDescentOptimizer(lr)
    elif config.optim == "rmsprop":
        return tf.train.RMSPropOptimizer(lr)
    elif config.optim == "adam":
        return tf.train.AdamOptimizer(lr, beta1=config.beta1, beta2=config.beta2)
    else:
        raise AssertionError("Wrong optimizer type!")
Then, we can create an optimizer as follows:
optimizer = optim(lr)
Here lr is the learning rate; for example, it can be 0.001.
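Because the training loop in Step 4 feeds the learning rate through feed_dict, lr is usually a placeholder rather than a plain Python float. Here is a minimal sketch of this setup, assuming TensorFlow 1.x and the optim() function and config object defined above:
import tensorflow as tf

lr = tf.placeholder(tf.float32, shape=[], name="lr")  # learning rate fed at run time
optimizer = optim(lr)  # build the optimizer from Step 1; config.optim decides its type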
Step 2: clip the gradient
Here is an example code:
# optimizer operation
trainable_vars = tf.trainable_variables()  # get variable list
grads, vars = zip(*optimizer.compute_gradients(loss))  # compute gradients of variables with respect to loss
grads_clip, _ = tf.clip_by_global_norm(grads, 3.0)  # l2 norm clipping by 3
In this code, we use the zip(*) function to separate the gradients and their variables. Then we use tf.clip_by_global_norm() to clip the gradients.
To understand tf.clip_by_global_norm(), you can view:
Understand tf.clip_by_global_norm(): Clip Values of Tensors – TensorFlow Tutorial
To understand zip(*), you can view:
Understand Python zip(*): Unzipping a Sequence with Examples – Python Tutorial
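To see what tf.clip_by_global_norm() actually does, here is a small self-contained sketch, assuming TensorFlow 1.x; the toy variables w, b and the quadratic loss are only for illustration:
import tensorflow as tf

# toy graph: two variables and a quadratic loss
w = tf.Variable([3.0, 4.0])
b = tf.Variable([2.0])
loss = tf.reduce_sum(tf.square(w)) + tf.reduce_sum(tf.square(b))

opt = tf.train.GradientDescentOptimizer(0.1)
grads, vars = zip(*opt.compute_gradients(loss))  # unzip (gradient, variable) pairs
grads_clip, global_norm = tf.clip_by_global_norm(grads, 3.0)  # rescale so the global l2 norm is at most 3

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(global_norm))  # about 10.77, the global norm of gradients [6, 8] and [4] before clipping
    print(sess.run(grads_clip))   # all gradients multiplied by 3.0 / 10.77
Clipping by global norm shrinks the length of the whole gradient vector but keeps its direction, which is why it helps against exploding gradients.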
Step 3: rescale the gradient of some variables
We can also rescale the gradients of some variables in TensorFlow. Here is an example:
grads_rescale = [0.01 * grad for grad in grads_clip[:2]] + grads_clip[2:]  # smaller gradient scale for w, b
You should determine which variables' gradients need to be rescaled; here grads_clip[:2] corresponds to the first two trainable variables in the graph.
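Selecting gradients by position (grads_clip[:2]) only works when you know the order of tf.trainable_variables(). As an alternative sketch, you can select them by variable name; the names "w:0" and "b:0" below are just examples and should be adapted to your own graph:
grads_rescale = []
for grad, var in zip(grads_clip, vars):
    if var.name in ("w:0", "b:0"):           # variables whose gradients should be shrunk (example names)
        grads_rescale.append(0.01 * grad)    # smaller gradient scale
    else:
        grads_rescale.append(grad)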
Step 4: train model by loss function
Finally, we can start to train the model with the loss function.
global_step = tf.Variable(0, name='global_step', trainable=False)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    train_op = optimizer.apply_gradients(zip(grads_rescale, vars), global_step=global_step)  # gradient update operation

# record loss
loss_summary = tf.summary.scalar("loss", loss)
merged = tf.summary.merge_all()
saver = tf.train.Saver()

# training session
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    os.makedirs(os.path.join(path, "Check_Point"), exist_ok=True)  # make folder to save model
    os.makedirs(os.path.join(path, "logs"), exist_ok=True)  # make folder to save log
    writer = tf.summary.FileWriter(os.path.join(path, "logs"), sess.graph)
    epoch = 0
    lr_factor = 1  # lr decay factor (1/2 per 10000 iterations)
    loss_acc = 0  # accumulated loss (for running average of loss)
    for iter in range(config.iteration):
        # run forward and backward propagation and update parameters
        _, loss_cur, summary = sess.run([train_op, loss, merged],
                                        feed_dict={batch: random_batch(), lr: config.lr * lr_factor})
Here we use sess.run() to run one training step; loss is the loss function, and loss_cur is its value for the current batch.
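The loop above ends right after sess.run(). As a rough sketch of what usually follows inside the same for loop (the logging and saving intervals are assumptions; only the 1/2-per-10000-iterations decay comes from the comment in the code above), indented to sit inside the loop:
        loss_acc += loss_cur                       # accumulate loss for a running average
        if (iter + 1) % 10000 == 0:
            lr_factor /= 2                         # halve the learning rate every 10000 iterations
        if (iter + 1) % 100 == 0:
            writer.add_summary(summary, iter)      # record the loss curve for TensorBoard
            print("iter %d, avg loss %.4f" % (iter + 1, loss_acc / 100))
            loss_acc = 0
        if (iter + 1) % 1000 == 0:
            saver.save(sess, os.path.join(path, "Check_Point", "model"), global_step=iter)  # save a checkpoint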