In TensorFlow, we can add update operations to the tf.GraphKeys.UPDATE_OPS collection so they can be managed in one place. In this tutorial, we will show you how to do this.
Look at the example below, written for TensorFlow 1.x:
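```python
import numpy as np
import tensorflow as tf

# Two non-trainable 5x10 variables, both initialized to zeros
x = tf.get_variable('x', [5, 10], dtype=tf.float32,
                    initializer=tf.constant_initializer(0), trainable=False)
y = tf.get_variable('y', [5, 10], dtype=tf.float32,
                    initializer=tf.constant_initializer(0), trainable=False)
diff = x - y
# Row indices of x to update
xy = tf.convert_to_tensor(np.array([0, 1, 2, 3, 4]))
# For each i: x[xy[i]] -= diff[i]
z = tf.compat.v1.scatter_sub(x, xy, diff)

# Register z in the UPDATE_OPS collection
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, z)
print(tf.GraphKeys.UPDATE_OPS)

# Retrieve everything registered under UPDATE_OPS
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
```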
Running this code prints:
```
update_ops
[<tf.Tensor 'ScatterSub:0' shape=(5, 10) dtype=float32_ref>]
```
From the result, we can see that:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, z) adds an update operation to the tf.GraphKeys.UPDATE_OPS collection, and tf.GraphKeys.UPDATE_OPS itself is just the string 'update_ops'.
tf.get_collection(tf.GraphKeys.UPDATE_OPS) returns a list of all operations in that collection.
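tf.compat.v1.scatter_sub(ref, indices, updates) subtracts updates[i] from row indices[i] of ref in place. In the example above diff is all zeros, so running z would leave x unchanged. As a framework-free analogue (NumPy, not TensorFlow), np.subtract.at performs the same row-wise subtraction, including the documented scatter_sub behavior that duplicate indices accumulate rather than overwrite. The array sizes here are made up for illustration.

```python
import numpy as np

ref = np.zeros((3, 4), dtype=np.float32)
indices = np.array([0, 0, 2])               # row 0 is targeted twice
updates = np.ones((3, 4), dtype=np.float32)

# Unbuffered in-place subtraction: ref[indices[i]] -= updates[i] for each i.
# Unlike `ref[indices] -= updates`, repeated indices accumulate here.
np.subtract.at(ref, indices, updates)

print(ref[:, 0])  # [-2.  0. -1.]
```

Row 0 receives both subtractions, row 1 is untouched, and row 2 receives one.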
tf.add_to_collection()
tf.add_to_collection() is defined as:
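```python
tf.add_to_collection(
    name,
    value
)
```
name: the key of the collection, for example tf.GraphKeys.UPDATE_OPS.
value: the value to add to the collection.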
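As a minimal sketch, assuming TensorFlow 2.x (where these functions live under tf.compat.v1), the code below builds a small graph and stores two values under a made-up key, my_collection. It shows that a collection is a named list kept by the graph: name can be any string, value does not have to be a tensor, and get_collection() returns a copy of the list rather than the list itself.

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    c = tf.constant(1.0, name="c")
    # The key is an ordinary string; "my_collection" is a hypothetical name.
    tf1.add_to_collection("my_collection", c)
    # Values are not restricted to tensors or ops.
    tf1.add_to_collection("my_collection", "just a string")

    items = tf1.get_collection("my_collection")
    # get_collection() returns a new list, so appending to it
    # does not change the collection stored in the graph.
    items.append("extra")

print(tf1.GraphKeys.UPDATE_OPS)                    # update_ops
print(len(items))                                  # 3
print(len(graph.get_collection("my_collection")))  # 2
```

If you need to modify the stored list itself, tf.compat.v1.get_collection_ref() returns the underlying list instead of a copy.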
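For center loss, the class centers must be updated at every training step. One way to do this is to create the training op under a control dependency on the centers update op:
```python
optimizer = tf.train.AdamOptimizer(0.001)
# train_op runs only after self.centers_update_op has run
with tf.control_dependencies([self.centers_update_op]):
    train_op = optimizer.minimize(self.loss, global_step=global_step)
```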
The full center loss implementation is described in the tutorial Implement Center Loss Function for Text Classification in TensorFlow.
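The snippet above assumes self.centers_update_op already exists. As a rough sketch of what such an op typically computes, the NumPy code below follows the center update rule from the center loss paper (Wen et al., 2016): each center is updated as c_j - alpha * sum_i (c_j - x_i) / (1 + n_j), summing over the n_j samples in the batch with label j. The function name update_centers and the value of alpha are hypothetical, and the exact code in the center loss tutorial may differ. np.subtract.at stands in for tf.scatter_sub, accumulating the contributions of samples that share a label.

```python
import numpy as np

def update_centers(centers, features, labels, alpha=0.5):
    """One center-loss center update in NumPy (illustrative sketch only)."""
    centers = centers.copy()
    # c_{y_i} - x_i for every sample in the batch
    diff = centers[labels] - features
    # n_{y_i}: how many samples in the batch share each sample's label
    counts = np.bincount(labels, minlength=len(centers))[labels]
    diff = alpha * diff / (1.0 + counts)[:, None]
    # Scatter-subtract; rows with the same label accumulate
    np.subtract.at(centers, labels, diff)
    return centers

centers = np.zeros((2, 2))
features = np.array([[2.0, 0.0], [4.0, 0.0], [0.0, 3.0]])
labels = np.array([0, 0, 1])
new_centers = update_centers(centers, features, labels, alpha=0.5)
print(new_centers)
```

With alpha = 0.5, center 0 moves from (0, 0) to approximately (1, 0) and center 1 to approximately (0, 0.75). The 1 + n_j denominator keeps a class with many samples in the batch from moving its center too far in a single step.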
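Alternatively, we can add the centers update op to tf.GraphKeys.UPDATE_OPS and make the training op depend on everything in that collection. The optimizer does not run the ops in UPDATE_OPS by itself, so the gradient ops must be created inside tf.control_dependencies(update_ops):
```python
# Register the centers update op in UPDATE_OPS
centers_update_op = tf.scatter_sub(centers, label, diff)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, centers_update_op)

global_step = tf.Variable(0, name="global_step", trainable=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Ops created in this block run only after all update ops have run
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(FLAGS.lr)
    grads_and_vars = optimizer.compute_gradients(mix_model.loss)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
```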
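The snippet above depends on names defined elsewhere (centers, label, diff, FLAGS, mix_model). The following self-contained sketch, assuming TensorFlow 2.x with tf.compat.v1 graph mode, keeps the same structure but swaps in hypothetical stand-ins: a scalar w with loss w**2 replaces mix_model.loss, a counter incremented with assign_add replaces the centers update, and plain gradient descent replaces Adam. It checks that every run of train_op also runs the op registered in UPDATE_OPS.

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    # Hypothetical trainable parameter with a simple quadratic loss
    w = tf1.get_variable("w", [], initializer=tf1.constant_initializer(3.0))
    loss = tf.square(w)

    # Hypothetical non-trainable state, standing in for the centers
    counter = tf1.get_variable("counter", [], trainable=False,
                               initializer=tf1.zeros_initializer())
    counter_update_op = counter.assign_add(1.0, read_value=False)
    tf1.add_to_collection(tf1.GraphKeys.UPDATE_OPS, counter_update_op)

    global_step = tf1.Variable(0, name="global_step", trainable=False)
    update_ops = tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)
    # Ops created in this block run only after every op in update_ops
    with tf.control_dependencies(update_ops):
        optimizer = tf1.train.GradientDescentOptimizer(0.1)
        grads_and_vars = optimizer.compute_gradients(loss)
        train_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=global_step)

    init = tf1.global_variables_initializer()

    with tf1.Session() as sess:
        sess.run(init)
        for _ in range(3):
            sess.run(train_op)
        counter_value, step_value = sess.run([counter, global_step])

print(counter_value, step_value)  # 3.0 3
```

After three training steps the counter and global_step both equal 3, so the update op ran once per step. Note that update_ops is read when the control-dependency block is built; ops added to the collection afterwards are not covered by train_op.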