Implement Center Loss Function for Text Classification in TensorFlow – TensorFlow Tutorial

August 19, 2021

Center loss was proposed in the paper “A Discriminative Feature Learning Approach for Deep Face Recognition”. It can also be used for text classification in NLP, where it reduces the differences among documents in the same category. In this tutorial, we will introduce how to implement it using TensorFlow.

Center Loss

Center Loss is defined as:

\[ L_C = \frac{1}{2}\sum_{i=1}^{m}\left\| x_i - c_{y_i} \right\|_2^2 \]

where \(m\) is the number of documents in a mini-batch.

For text classification, \(x_i\) is the document vector and \(c_{y_i}\) is the center vector of class \(y_i\), where \(y_i\) is the index of the label of \(x_i\) among all class labels.
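To make the definition concrete, here is a minimal NumPy sketch that evaluates \(L_C\) for a toy batch. The function name center_loss_value and the toy numbers are illustrative assumptions, not part of the original tutorial. Note that tf.nn.l2_loss(t) computes sum(t ** 2) / 2, so it matches this definition.

```python
import numpy as np

def center_loss_value(x, y, centers):
    """L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2 for one mini-batch.

    x:       [m, d] document vectors
    y:       [m] integer class labels
    centers: [num_classes, d] one center vector per class
    """
    diff = x - centers[y]  # x_i - c_{y_i}, row by row
    return 0.5 * np.sum(diff ** 2)

# Toy batch: three 2-dimensional documents, two classes (hypothetical data)
x = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
y = np.array([0, 0, 1])
centers = np.array([[2.0, 0.0], [0.0, 0.0]])
print(center_loss_value(x, y, centers))  # 0.5 * (1 + 1 + 4) = 3.0
```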

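To minimize \(L_C\), we keep a center matrix that holds one center vector for every class label. Minimizing \(L_C\) pulls each document vector toward the center of its own class. For example:

[Figure: the effect of the center loss function]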
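Why does a center per class minimize \(L_C\)? Setting \(\partial L_C / \partial c_j = \sum_{i: y_i = j} (c_j - x_i) = 0\) shows that, for a fixed batch, the best \(c_j\) is the mean of the document vectors in class \(j\). The sketch below, with a hypothetical helper class_means and random toy data, builds such a center matrix and checks that moving the centers away from the class means increases \(L_C\).

```python
import numpy as np

def class_means(x, y, num_classes):
    """Center matrix [num_classes, d]: row j is the mean of the x_i with y_i == j."""
    centers = np.zeros((num_classes, x.shape[1]))
    for j in range(num_classes):
        members = x[y == j]
        if len(members) > 0:  # a class absent from the batch keeps a zero center
            centers[j] = members.mean(axis=0)
    return centers

def center_loss_value(x, y, centers):
    return 0.5 * np.sum((x - centers[y]) ** 2)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
y = np.array([0, 1, 2, 0, 1, 2, 0, 1])
centers = class_means(x, y, num_classes=3)
print(centers.shape)  # (3, 4)

# Shifting every center away from its class mean can only increase L_C
perturbed = centers + 0.1
print(center_loss_value(x, y, centers) < center_loss_value(x, y, perturbed))  # True
```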

How to update \(c_{y_i}\)?

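The centers are not learned by the optimizer. Instead, during training we compute the gradient of \(L_C\) and update the centers of the classes that appear in each mini-batch:

\[ \frac{\partial L_C}{\partial x_i} = x_i - c_{y_i} \]

\[ \Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j) \cdot (c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)}, \qquad c_j^{t+1} = c_j^{t} - \alpha \cdot \Delta c_j^{t} \]

where \(\delta(\cdot)\) is 1 if the condition is true and 0 otherwise, and \(\alpha\) controls how fast the centers move.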
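The per-class update rule can be sketched directly in NumPy, looping over classes. This is a minimal sketch under assumed toy data; update_centers is a hypothetical name. The 1 in the denominator keeps the step defined and shrinks it when only a few samples of a class are in the batch.

```python
import numpy as np

def update_centers(centers, x, y, alpha):
    """One center update step:
    delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j),  c_j <- c_j - alpha * delta_c_j
    """
    new_centers = centers.copy()
    for j in range(centers.shape[0]):
        members = x[y == j]
        n_j = len(members)
        if n_j == 0:
            continue  # delta_c_j = 0 when class j does not appear in the batch
        delta = np.sum(centers[j] - members, axis=0) / (1 + n_j)
        new_centers[j] = centers[j] - alpha * delta
    return new_centers

centers = np.zeros((2, 2))
x = np.array([[2.0, 0.0], [4.0, 0.0], [0.0, 3.0]])
y = np.array([0, 0, 1])
print(update_centers(centers, x, y, alpha=0.5))
# class 0 -> [1.0, 0.0], class 1 -> [0.0, 0.75]
```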

How to implement center loss in tensorflow?

We will create a function to implement this center loss. Here is the example:

import tensorflow as tf
import numpy as np

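# features: 64 * 200, 2 dims (batch_size * feature_dim)
# labels: 64 * 10, 2 dims (batch_size * num_classes, one-hot)
# alpha: in [0, 1]; alpha = 0 would leave the centers unchanged
def center_loss(features, labels, alpha=0.5):
    features_dim = features.get_shape()[1]
    label_num = labels.get_shape()[1]
    # create a center matrix, one row per class; it is updated by
    # centers_update_op below, not by the optimizer
    centers = tf.get_variable('doc_centers', [label_num, features_dim], dtype=tf.float32,
                              initializer=tf.constant_initializer(0), trainable=False)
    # convert one-hot labels to class indices
    label = tf.argmax(labels, 1)
    label = tf.reshape(label, [-1])  # [64]

    # c_{y_i} for every document in the batch
    centers_batch = tf.gather(centers, label)

    # c_{y_i} - x_i
    diff = centers_batch - features

    # n_j for every sample: how many times its label appears in the batch
    unique_label, unique_idx, unique_count = tf.unique_with_counts(label)
    appear_times = tf.gather(unique_count, unique_idx)
    appear_times = tf.reshape(appear_times, [-1, 1])

    diff = diff / tf.cast((1 + appear_times), tf.float32)
    diff = alpha * diff

    # scatter_sub adds up the contributions of duplicate labels,
    # so each center becomes c_j - alpha * delta_c_j
    centers_update_op = tf.scatter_sub(centers, label, diff)
    # tf.nn.l2_loss(t) = sum(t ** 2) / 2, i.e. L_C
    loss = tf.nn.l2_loss(features - centers_batch)
    return loss, centers, centers_update_op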
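Because the TensorFlow function only builds graph operations, it can be hard to see that dividing each sample by \(1 + n_j\) and then calling tf.scatter_sub reproduces the per-class update rule. The NumPy sketch below mirrors each op (np.subtract.at stands in for tf.scatter_sub, since both accumulate duplicate indices, and np.bincount stands in for tf.unique_with_counts) and checks the result against a direct per-class computation. The function names are hypothetical; the shapes 64 * 200 and 64 * 10 follow the comments in the code above.

```python
import numpy as np

def center_update_like_tf(centers, features, labels_onehot, alpha):
    """NumPy mirror of the graph built by center_loss()."""
    label = np.argmax(labels_onehot, axis=1)                 # tf.argmax(labels, 1)
    centers_batch = centers[label]                           # tf.gather(centers, label)
    diff = centers_batch - features
    counts = np.bincount(label, minlength=centers.shape[0])  # occurrences of each class
    appear_times = counts[label].reshape(-1, 1)              # n_j for every sample
    diff = alpha * diff / (1 + appear_times)
    new_centers = centers.copy()
    np.subtract.at(new_centers, label, diff)                 # duplicate indices accumulate
    loss = 0.5 * np.sum((features - centers_batch) ** 2)     # tf.nn.l2_loss
    return loss, new_centers

def update_centers_per_class(centers, x, y, alpha):
    new = centers.copy()
    for j in range(centers.shape[0]):
        members = x[y == j]
        if len(members):
            new[j] -= alpha * np.sum(centers[j] - members, axis=0) / (1 + len(members))
    return new

rng = np.random.default_rng(1)
centers = rng.normal(size=(10, 200))
features = rng.normal(size=(64, 200))
y = rng.integers(0, 10, size=64)
labels_onehot = np.eye(10)[y]
loss, new_centers = center_update_like_tf(centers, features, labels_onehot, alpha=0.5)
print(np.allclose(new_centers, update_centers_per_class(centers, features, y, 0.5)))  # True
```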

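The parameters are:

features: the document vectors, a 2-D tensor, for example 64 * 200, where 64 is the batch size and 200 is the dimension of each document vector.

labels: the one-hot label vectors, a 2-D tensor, for example 64 * 10, where 64 is the batch size and 10 is the number of classes.

alpha: controls how fast the centers are updated. It should lie in [0, 1]; with alpha = 0 the centers never change.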
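A small NumPy sketch of what these parameters mean in practice, using hypothetical toy values: the one-hot labels are turned into class indices (as tf.argmax(labels, 1) does), and for a class that has a single sample in the batch (\(n_j = 1\)) the center moves a fraction alpha / 2 of the way toward that sample.

```python
import numpy as np

# labels: one-hot rows -> class indices used to gather centers
labels = np.eye(10)[[3, 7, 3]]  # 3 documents, 10 classes
print(np.argmax(labels, axis=1))  # [3 7 3]

def one_step(center, x, alpha):
    """Center update when the class has a single sample (n_j = 1) in the batch."""
    delta = (center - x) / (1 + 1)
    return center - alpha * delta

center = np.array([0.0, 0.0])
x = np.array([4.0, 2.0])
for alpha in (0.0, 0.5, 1.0):
    print(alpha, one_step(center, x, alpha))
# 0.0 -> [0, 0], 0.5 -> [1, 0.5], 1.0 -> [2, 1]
```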

Then we can use this function as follows, where 0.001 is the weight that balances the softmax loss and the center loss:

self.center_loss, self.doc_center, self.centers_update_op = lossutil.center_loss(features = self.doc_output, labels = self.input_y)
self.loss = loss_softmax + 0.001 * self.center_loss
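The combined objective is \(L = L_S + \lambda L_C\) with \(\lambda = 0.001\). The NumPy sketch below computes it for a toy batch. It assumes, since the tutorial does not define it, that loss_softmax is the batch-averaged softmax cross-entropy; softmax_cross_entropy and joint_loss are hypothetical helper names.

```python
import numpy as np

def softmax_cross_entropy(logits, y):
    """Mean softmax cross-entropy; y holds integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(y)), y])

def joint_loss(logits, features, y, centers, lam=0.001):
    """L = L_S + lambda * L_C, with L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2."""
    l_s = softmax_cross_entropy(logits, y)
    l_c = 0.5 * np.sum((features - centers[y]) ** 2)
    return l_s + lam * l_c

logits = np.zeros((2, 3))  # uniform predictions over 3 classes
features = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([0, 1])
centers = np.zeros((3, 2))
print(joint_loss(logits, features, y, centers))  # log(3) + 0.001 * 1.0, about 1.0996
```

Because \(L_C\) here is a sum over the batch while the softmax term is a mean, the weight also absorbs the batch size; changing the batch size changes the effective balance between the two terms.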

To understand how to use tf.unique_with_counts(), you can read:

tf.unique_with_counts(): Count the Number of Each Element in Tensor – TensorFlow Tutorial
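In center_loss(), tf.unique_with_counts() and tf.gather() together give every sample the number of times its own label appears in the batch. The NumPy sketch below, on a hypothetical label vector, reproduces that appear_times tensor with np.unique. np.unique sorts the unique values while tf.unique_with_counts keeps first-appearance order, but the per-sample counts gathered through the inverse index are the same.

```python
import numpy as np

label = np.array([2, 0, 2, 1, 2])
unique_label, unique_idx, unique_count = np.unique(label, return_inverse=True, return_counts=True)
appear_times = unique_count[unique_idx].reshape(-1, 1)  # [batch, 1], like the TF code
print(appear_times.ravel())  # [3 1 3 1 3]
```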

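How to run the center update operation during training?

The centers are created with trainable=False, so the optimizer will not update them; we must run centers_update_op ourselves. In TensorFlow, we can make the training step depend on it as follows: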

optimizer = tf.train.AdamOptimizer(0.001)

with tf.control_dependencies([self.centers_update_op]):
    train_op = optimizer.minimize(self.loss, global_step=global_step)

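We can also add self.centers_update_op to tf.GraphKeys.UPDATE_OPS, so that it runs together with the other operations in that collection, provided the training step depends on tf.get_collection(tf.GraphKeys.UPDATE_OPS).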

Here is the tutorial:

Add Tensor Update Operation to tf.GraphKeys.UPDATE_OPS – TensorFlow Tutorial
