Center Loss was proposed in the paper “A Discriminative Feature Learning Approach for Deep Face Recognition”. It can also be used for text classification in NLP, where it reduces the differences among documents in the same category. In this tutorial, we will introduce how to implement it using TensorFlow.
Center Loss
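Center Loss is defined as:
\[ L_C = \frac{1}{2}\sum_{i=1}^{m}\left\| x_i - c_{y_i} \right\|_2^2 \]
where \(m\) is the size of a mini-batch.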
For text classification, \(x_i\) is the vector of the \(i\)-th document, \(y_i\) is its label index among all class labels, and \(c_{y_i}\) is the center vector of class \(y_i\).
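As a quick check of the definition, here is a minimal pure-Python sketch (no TensorFlow) that evaluates \(L_C\) on a hand-made batch of three 2-dimensional document vectors and two classes. The function name center_loss_value and the numbers are illustrative only. The factor 1/2 matches tf.nn.l2_loss, which computes sum(t ** 2) / 2.

```python
def center_loss_value(features, labels, centers):
    """L_C = 1/2 * sum_i ||x_i - c_{y_i}||_2^2.

    features: list of m document vectors x_i
    labels:   list of m class indices y_i
    centers:  list of class center vectors, one per class
    """
    total = 0.0
    for x, y in zip(features, labels):
        c = centers[y]  # c_{y_i}: the center of this document's class
        total += sum((xk - ck) ** 2 for xk, ck in zip(x, c))
    return 0.5 * total


features = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1]
centers = [[2.0, 1.0], [0.0, 0.0]]
print(center_loss_value(features, labels, centers))
```

Here the squared distances are 2, 2 and 1, so the loss is half of 5.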
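Minimizing \(L_C\) pulls every document vector toward the center of its class, and the centers of all class labels together form a center matrix. For example, with 10 classes and 200-dimensional document vectors, the center matrix is 10 * 200, and row \(j\) is the center of class \(j\).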
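The sketch below builds such a center matrix in pure Python and looks up \(c_{y_i}\) for a small one-hot label batch. The sizes (10 classes, 200 dimensions) and the helper names argmax and one_hot are hypothetical; the two list comprehensions at the end play the roles of tf.argmax(labels, 1) and tf.gather(centers, label) in the TensorFlow code.

```python
# Hypothetical sizes: 10 classes, 200-dimensional document vectors.
num_classes, dim = 10, 200

# Center matrix: one row per class, initialized to zeros.
centers = [[0.0] * dim for _ in range(num_classes)]


def argmax(row):
    """Index of the largest entry of one row (like tf.argmax(labels, 1) per row)."""
    return max(range(len(row)), key=lambda k: row[k])


def one_hot(index, n):
    """One-hot vector of length n with a 1.0 at position index."""
    return [1.0 if k == index else 0.0 for k in range(n)]


# A one-hot label batch of size 3 with classes 2, 7 and 2.
labels_one_hot = [one_hot(2, num_classes), one_hot(7, num_classes), one_hot(2, num_classes)]
label_idx = [argmax(row) for row in labels_one_hot]

# c_{y_i} for every sample in the batch.
centers_batch = [centers[y] for y in label_idx]
print(len(centers), len(centers[0]), label_idx, len(centers_batch))
```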
How to update \(c_{y_i}\)?
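The centers have to follow the document vectors as they change during training, so the paper updates them on every mini-batch instead of over the whole training set. The gradient of \(L_C\) with respect to \(x_i\) is
\[ \frac{\partial L_C}{\partial x_i} = x_i - c_{y_i} \]
and each center \(c_j\) is updated with
\[ \Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)}, \qquad c_j^{t+1} = c_j^{t} - \alpha\,\Delta c_j^{t} \]
where \(\delta(\cdot)\) is 1 if the condition holds and 0 otherwise, and \(\alpha \in [0, 1]\) controls the learning rate of the centers.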
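The per-batch update of the centers can be checked by hand with a minimal pure-Python sketch. update_centers is a hypothetical helper, not a TensorFlow function: for each class j it sums (c_j - x_i) over the samples of that class, divides by 1 + n_j (n_j is the number of those samples), and moves c_j by alpha times the result. Classes that do not appear in the batch are left unchanged.

```python
def update_centers(features, labels, centers, alpha):
    """One mini-batch update:
    delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j)
    c_j      <- c_j - alpha * delta_c_j
    """
    new_centers = [list(c) for c in centers]
    for j, c in enumerate(centers):
        members = [x for x, y in zip(features, labels) if y == j]
        n_j = len(members)
        if n_j == 0:
            continue  # class j is not in this batch: keep its center
        for k in range(len(c)):
            delta = sum(c[k] - x[k] for x in members) / (1 + n_j)
            new_centers[j][k] = c[k] - alpha * delta
    return new_centers


features = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1]
centers = [[0.0, 0.0], [0.0, 0.0], [5.0, 5.0]]
print(update_centers(features, labels, centers, alpha=0.5))
```

The 1 in the denominator keeps the step below the full distance to the batch mean of each class, so a single noisy batch cannot move a center all the way.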
How to implement center loss in TensorFlow?
We will create a function to implement center loss. Here is an example:
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.scatter_sub)


# features: batch_size * feature_dim, e.g. 64 * 200, 2 dims
# labels: one-hot, batch_size * num_classes, e.g. 64 * 10, 2 dims
def center_loss(features, labels, alpha=0.5):
    features_dim = features.get_shape()[1]
    label_num = labels.get_shape()[1]
    # create a center matrix: one row per class, updated manually (not by the optimizer)
    centers = tf.get_variable('doc_centers', [label_num, features_dim], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    # convert one-hot labels to class indices
    label = tf.argmax(labels, 1)
    label = tf.reshape(label, [-1])  # [64]
    # c_{y_i} for every sample in the batch
    centers_batch = tf.gather(centers, label)
    # center loss: 1/2 * sum ||x_i - c_{y_i}||^2
    loss = tf.nn.l2_loss(features - centers_batch)
    # center update: c_j <- c_j - alpha * sum(c_j - x_i) / (1 + n_j)
    diff = centers_batch - features
    unique_label, unique_idx, unique_count = tf.unique_with_counts(label)
    appear_times = tf.gather(unique_count, unique_idx)  # n_{y_i} for every sample
    appear_times = tf.reshape(appear_times, [-1, 1])
    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff
    # duplicate indices accumulate, so each row j receives the sum over its samples
    centers_update_op = tf.scatter_sub(centers, label, diff)
    return loss, centers, centers_update_op
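Why does the function divide each sample's difference by 1 + appear_times before calling tf.scatter_sub? Because tf.scatter_sub adds up the contributions of duplicate indices, subtracting alpha * (c_{y_i} - x_i) / (1 + n_{y_i}) once per sample gives exactly the per-class rule from the paper. The pure-Python sketch below emulates both versions and compares them; scatter_style_update and per_class_update are hypothetical names used only for illustration, and collections.Counter stands in for tf.unique_with_counts.

```python
from collections import Counter


def scatter_style_update(features, labels, centers, alpha):
    """Mimics the graph: each sample i subtracts
    alpha * (c_{y_i} - x_i) / (1 + n_{y_i}) from row y_i; duplicates accumulate."""
    counts = Counter(labels)              # n_j for every class in the batch
    old = [list(c) for c in centers]      # centers_batch is read before the update
    new = [list(c) for c in centers]
    for x, y in zip(features, labels):
        for k in range(len(x)):
            diff = (old[y][k] - x[k]) / (1 + counts[y])
            new[y][k] -= alpha * diff
    return new


def per_class_update(features, labels, centers, alpha):
    """Paper's rule: c_j <- c_j - alpha * sum_{y_i = j}(c_j - x_i) / (1 + n_j)."""
    new = [list(c) for c in centers]
    for j, c in enumerate(centers):
        members = [x for x, y in zip(features, labels) if y == j]
        if members:
            for k in range(len(c)):
                step = sum(c[k] - x[k] for x in members) / (1 + len(members))
                new[j][k] = c[k] - alpha * step
    return new


features = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1]
centers = [[0.5, 0.5], [1.0, -1.0]]
a = scatter_style_update(features, labels, centers, alpha=0.5)
b = per_class_update(features, labels, centers, alpha=0.5)
print(a)
print(b)
```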
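The parameters are:
features: the document vectors, a 2-D tensor, for example of shape 64 * 200, where 64 is the batch size and 200 is the dimension of a document vector.
labels: the one-hot label vectors, a 2-D tensor, for example of shape 64 * 10, where 64 is the batch size and 10 is the number of classes.
alpha: the learning rate of the centers, a value in [0, 1] that controls how fast the centers are updated. If alpha is 0, the centers are never updated.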
Then we can use this function as follows:
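# lossutil: the module that holds center_loss() defined above
self.center_loss, self.doc_center, self.centers_update_op = lossutil.center_loss(
    features=self.doc_output, labels=self.input_y)
# 0.001 weights the center loss against the softmax loss (lambda in the paper)
self.loss = loss_softmax + 0.001 * self.center_loss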
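The total loss above has the form L = L_S + λ L_C, with λ = 0.001. The following pure-Python sketch shows how the two terms combine. It assumes that loss_softmax is the mean softmax cross-entropy over the batch (the original code does not show how loss_softmax is computed), and the helper names are hypothetical.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample: -log(softmax(logits)[label]), computed stably."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[label]


def total_loss(batch_logits, features, labels, centers, lam=0.001):
    """Mean softmax loss + lam * center loss (assumed form of self.loss)."""
    softmax = sum(softmax_cross_entropy(z, y) for z, y in zip(batch_logits, labels)) / len(labels)
    center = 0.5 * sum(
        sum((xk - ck) ** 2 for xk, ck in zip(x, centers[y]))
        for x, y in zip(features, labels)
    )
    return softmax + lam * center


features = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1]
centers = [[2.0, 1.0], [0.0, 0.0]]
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
print(total_loss(logits, features, labels, centers))
```

With uniform logits every sample's cross-entropy is log 2, and the center term (2.5 here) only adds a small amount because λ is small; a larger λ pulls same-class documents together more strongly.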
To understand how to use tf.unique_with_counts(), you can read:
tf.unique_with_counts(): Count the Number of Each Element in Tensor – TensorFlow Tutorial
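The outputs of tf.unique_with_counts() and the appear_times step in center_loss() can be mimicked in pure Python. This is an illustrative re-implementation (the function name is reused only for readability), not TensorFlow's code; it follows the documented behavior that y holds the unique values in order of first occurrence, idx maps every input element to its position in y, and count holds the number of occurrences of each value in y.

```python
def unique_with_counts(values):
    """Pure-Python analogue of tf.unique_with_counts for a 1-D list."""
    y, idx, count = [], [], []
    position = {}
    for v in values:
        if v not in position:
            position[v] = len(y)
            y.append(v)
            count.append(0)
        p = position[v]
        idx.append(p)
        count[p] += 1
    return y, idx, count


label = [2, 7, 2, 5, 7, 2]
unique_label, unique_idx, unique_count = unique_with_counts(label)
# like tf.gather(unique_count, unique_idx): n_{y_i} for every sample
appear_times = [unique_count[i] for i in unique_idx]
print(unique_label, unique_idx, unique_count, appear_times)
```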
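How to run the center update operation during training?
Because the centers are created with trainable=False, the optimizer does not change them, so centers_update_op has to run at every training step. In TensorFlow, we can do this as follows: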
optimizer = tf.train.AdamOptimizer(0.001)
with tf.control_dependencies([self.centers_update_op]):
    train_op = optimizer.minimize(self.loss, global_step=global_step)
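Alternatively, we can add self.centers_update_op to the tf.GraphKeys.UPDATE_OPS collection with tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, self.centers_update_op) and make the training op depend on tf.get_collection(tf.GraphKeys.UPDATE_OPS).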
Here is the tutorial:
Add Tensor Update Operation to tf.GraphKeys.UPDATE_OPS – TensorFlow Tutorial
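Whichever way the update operation is wired in, it has to run on every step. The pure-Python sketch below runs only the center update (the job of centers_update_op) on a fixed, hypothetical batch whose features are held constant for illustration. With repeated updates each center converges to the mean of its class; with zero steps, or with alpha = 0, the centers stay at their zero initialization, and the center loss would only pull the features toward the origin. train_centers is a hypothetical helper name.

```python
def train_centers(features, labels, num_classes, alpha, steps):
    """Apply the per-batch center update `steps` times on a fixed batch.
    Centers start at zero, like tf.constant_initializer(0)."""
    dim = len(features[0])
    centers = [[0.0] * dim for _ in range(num_classes)]
    for _ in range(steps):
        new = [list(c) for c in centers]
        for j in range(num_classes):
            members = [x for x, y in zip(features, labels) if y == j]
            if not members:
                continue  # class j is absent from the batch
            for k in range(dim):
                delta = sum(centers[j][k] - x[k] for x in members) / (1 + len(members))
                new[j][k] = centers[j][k] - alpha * delta
        centers = new
    return centers


features = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]
labels = [0, 0, 1]
print(train_centers(features, labels, num_classes=2, alpha=0.5, steps=0))
print(train_centers(features, labels, num_classes=2, alpha=0.5, steps=100))
```

For class 0 the class mean is [2.0, 1.0] and for class 1 it is [0.0, 1.0]; the second call ends up at those values up to floating-point rounding.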