A Step-by-Step Guide to Applying Gradient Clipping in TensorFlow – TensorFlow Tutorial

March 17, 2022

In this tutorial, we will introduce how to apply gradient clipping in TensorFlow. Gradient clipping is very useful for keeping model training stable.

Step 1: create an optimizer with a learning rate

For example:

import tensorflow as tf

def optim(lr):
    """ return an optimizer determined by the configuration
    :param lr: learning rate, a float or a scalar tensor
    :return: tf optimizer
    """
    # config is the project configuration object (optimizer type, beta1, beta2), defined elsewhere
    if config.optim == "sgd":
        return tf.train.GradientDescentOptimizer(lr)
    elif config.optim == "rmsprop":
        return tf.train.RMSPropOptimizer(lr)
    elif config.optim == "adam":
        return tf.train.AdamOptimizer(lr, beta1=config.beta1, beta2=config.beta2)
    else:
        raise AssertionError("Wrong optimizer type!")

Then, we can create an optimizer as follows:

optimizer = optim(lr)

Here lr is the learning rate, for example 0.001.
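In the training code in Step 4, the learning rate is fed through feed_dict, so lr can also be defined as a placeholder rather than a Python float. A minimal sketch (the placeholder name is an assumption):

lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")   # fed at run time, e.g. feed_dict={lr: 0.001}
optimizer = optim(lr)   # tf.train optimizers also accept a scalar tensor as the learning rate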

Step 2: clip the gradients

Here is an example code:

    # optimizer operation
    trainable_vars = tf.trainable_variables()                 # get the list of trainable variables
    grads, vars = zip(*optimizer.compute_gradients(loss))     # compute gradients of the variables with respect to the loss
    grads_clip, _ = tf.clip_by_global_norm(grads, 3.0)        # clip the gradients so that their global L2 norm is at most 3.0

In this code, we use zip(*) to split the (gradient, variable) pairs returned by compute_gradients() into a list of gradients and a list of variables. Then we use tf.clip_by_global_norm() to clip the gradients.

Understand tf.clip_by_global_norm(): Clip Values of Tensors – TensorFlow Tutorial

To understand zip(*), you can view:

Understand Python zip(*): Unzipping a Sequence with Examples – Python Tutorial
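To get a quick feel for what tf.clip_by_global_norm() does, here is a small stand-alone sketch with made-up gradient values (the tensors and the clip norm 3.0 are only for illustration):

import tensorflow as tf

# two fake "gradients"; their global L2 norm is sqrt(3**2 + 4**2 + 12**2) = 13
grads = [tf.constant([3.0, 4.0]), tf.constant([12.0])]
grads_clip, global_norm = tf.clip_by_global_norm(grads, 3.0)   # every gradient is scaled by 3.0 / 13

with tf.Session() as sess:
    print(sess.run(global_norm))   # 13.0
    print(sess.run(grads_clip))    # [array([0.692, 0.923]), array([2.769])] (approximately)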

Step 3: rescale the gradients of some variables

We can also rescale the gradients of some variables in TensorFlow. Here is an example:

grads_rescale = [0.01 * grad for grad in grads_clip[:2]] + grads_clip[2:]   # apply a smaller gradient scale to the first two variables (w, b)

You should determine which variables need to be rescaled.
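If you are not sure which positions in grads_clip correspond to which variables, you can match them by name instead of by index. A minimal sketch (the "w" and "b" name prefixes are assumptions for this example):

# print the variable order so you know which gradients you are rescaling
for i, v in enumerate(vars):
    print(i, v.name)

# rescale by variable name instead of by position
grads_rescale = [0.01 * g if v.name.startswith(("w", "b")) else g
                 for g, v in zip(grads_clip, vars)]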

Step 4: train the model with the loss function

Finally, we can start to train the model using the loss function.

    global_step = tf.Variable(0, name='global_step', trainable=False)
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.apply_gradients(zip(grads_rescale, vars), global_step=global_step)   # gradient update operation

    # record loss
    loss_summary = tf.summary.scalar("loss", loss)
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()

    # training session
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        os.makedirs(os.path.join(path, "Check_Point"), exist_ok=True)   # make folder to save the model
        os.makedirs(os.path.join(path, "logs"), exist_ok=True)          # make folder to save the logs
        writer = tf.summary.FileWriter(os.path.join(path, "logs"), sess.graph)
        epoch = 0
        lr_factor = 1   # lr decay factor (halved every 10000 iterations)
        loss_acc = 0    # accumulated loss (for a running average of the loss)

        for iter in range(config.iteration):
            # run forward and backward propagation and update the parameters
            _, loss_cur, summary = sess.run([train_op, loss, merged],
                                            feed_dict={batch: random_batch(), lr: config.lr * lr_factor})

Here we use sess.run() to run one training step; loss is the loss tensor and loss_cur is its current value.
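The loop above only shows the forward/backward step. A possible continuation of the loop body, matching the lr_factor and loss_acc bookkeeping declared earlier (the decay interval, logging interval, and checkpoint name are assumptions):

            loss_acc += loss_cur                          # accumulate the loss for a running average
            writer.add_summary(summary, iter)             # write the merged summary for TensorBoard

            if (iter + 1) % 10000 == 0:                   # assumed decay interval: halve the lr every 10000 iterations
                lr_factor /= 2

            if (iter + 1) % 1000 == 0:                    # assumed logging / checkpoint interval
                print("iter %d, avg loss %.4f" % (iter + 1, loss_acc / 1000))
                loss_acc = 0
                saver.save(sess, os.path.join(path, "Check_Point", "model"), global_step=iter + 1)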
