Implement Supervised Contrastive Loss in a Batch with PyTorch – PyTorch Tutorial

By | January 31, 2023

Supervised contrastive loss is widely used in text and image classification. In this tutorial, we will show you how to implement it in PyTorch.

Supervised Contrastive Loss

We can define this loss as follows:

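$$\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

where N is the batch size, z_i is the L2-normalized embedding of sample i, τ is the temperature (r = 0.07 in the code below), A(i) is the set of all samples in the batch except i, and P(i) ⊆ A(i) is the set of samples that share the label of i.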

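The main idea of contrastive learning is to pull positive pairs closer together in the embedding space and push negative pairs further apart. In the supervised setting, samples with the same label form positive pairs, and samples with different labels form negative pairs.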
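To make the formula concrete, here is a minimal, deliberately unvectorized sketch that follows it term by term. For each anchor i, it takes the softmax over all other samples A(i), evaluates -log of that softmax at each positive p in P(i), and averages over the positives. The helper name `supcon_loss_naive` is hypothetical. Two choices are assumptions made for illustration: the final average over all N anchors, and giving anchors with an empty P(i) a loss of 0.

```python
import math
import torch
import torch.nn.functional as F

def supcon_loss_naive(features, labels, temperature=0.07):
    """Loop-based supervised contrastive loss (hypothetical helper, for illustration)."""
    z = F.normalize(features, dim=1)  # unit-length rows, so z_i . z_j is a cosine similarity
    n = z.shape[0]
    losses = []
    for i in range(n):
        # A(i): every sample except the anchor itself
        others = [a for a in range(n) if a != i]
        # P(i): samples in A(i) that share the anchor's label
        positives = [p for p in others if labels[p] == labels[i]]
        if not positives:
            losses.append(0.0)  # assumption: an anchor without positives contributes 0
            continue
        denom = sum(math.exp(float(z[i] @ z[a]) / temperature) for a in others)
        loss_i = 0.0
        for p in positives:
            loss_i += -math.log(math.exp(float(z[i] @ z[p]) / temperature) / denom)
        losses.append(loss_i / len(positives))  # the 1/|P(i)| factor
    return sum(losses) / n  # average over the N anchors
```

A loop like this is too slow for training, but it can serve as a reference when checking a batched implementation on small inputs.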

Supervised Contrastive Loss in a Training Batch

We usually train a model in batches. For example, if we train a text classification model, we may feed 64 documents into the model at each training step.

Within a batch, some samples may belong to the same class.

For example:

Document Label
D1 0
D2 1
D3 2
D4 1
D5 1

There are 5 documents in this batch. For D2, the positive samples are D4 and D5, which share its label 1, while D1 and D3 are negative samples.
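A minimal sketch of how these positive and negative sets can be read off the labels with boolean masks, using the five labels from the table. The variable names are illustrative only. Indices are 0-based, so row 1 corresponds to D2.

```python
import torch

# Labels of D1..D5 from the table above
labels = torch.tensor([0, 1, 2, 1, 1])
n = labels.shape[0]

# same_class[i, j] is True when documents i and j share a label
same_class = labels.unsqueeze(1) == labels.unsqueeze(0)
# A document is not its own positive, so drop the diagonal
not_self = ~torch.eye(n, dtype=torch.bool)

positive_mask = same_class & not_self
negative_mask = ~same_class

# Row 1 is D2 (0-based indexing)
print(positive_mask[1].nonzero().flatten().tolist())  # [3, 4] -> D4, D5
print(negative_mask[1].nonzero().flatten().tolist())  # [0, 2] -> D1, D3
```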

Based on this batch, we can compute a supervised contrastive loss in which every document serves as the anchor in turn.

Implement Supervised Contrastive Loss in PyTorch

Here we implement this loss as a PyTorch module (nn.Module).

For example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, r=0.07):
        super(ContrastiveLoss, self).__init__()
        self.r = r  # temperature

    # x: 64*200, y: 64*200 -> 64*64 matrix of pairwise cosine similarities
    def cosine_similarity(self, x, y):
        cosine = F.linear(F.normalize(x, dim=1), F.normalize(y, dim=1))
        return cosine

    # outputs: 64*200 embeddings, targets: 64 class labels
    def forward(self, outputs, targets):
        # normalize the embeddings so that the dot products are cosine similarities
        cosine = self.cosine_similarity(outputs, outputs)
        dot_product_tempered = cosine / self.r
        # subtract the row-wise max for numerical stability before exponentiating
        exp_dot_tempered = (
                torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0].detach()) + 1e-8
        )
        # mask_similar_class[i][j] = 1 if samples i and j have the same label
        mask_similar_class = (targets.unsqueeze(1) == targets.unsqueeze(0)).float()
        # mask_anchor_out removes self-comparisons (the diagonal); built on the input's device
        mask_anchor_out = 1 - torch.eye(exp_dot_tempered.shape[0], device=outputs.device)
        mask_combined = mask_similar_class * mask_anchor_out  # positives of each anchor
        cardinality_per_samples = torch.sum(mask_combined, dim=1)  # |P(i)|

        # -log of the softmax over all non-anchor samples A(i)
        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        # average over positives; anchors without positives get 0
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / (cardinality_per_samples + 1e-8)
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        return supervised_contrastive_loss

We can use it as follows:

if __name__ == "__main__":
    x = torch.randn(5, 192)
    ce = ContrastiveLoss()

    label = torch.LongTensor([0, 3, 9, 3, 1])

    l = ce(x, label)
    print(l)

Running this code prints a loss value such as:

tensor(0.5545)

Because x is generated randomly, the exact value will differ from run to run.
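One detail is worth checking in this example. With labels [0, 3, 9, 3, 1], only the two samples labelled 3 have a positive partner. The other three anchors get a per-sample loss of 0, yet torch.mean still divides by all 5 rows, so the reported value is smaller than the average over anchors that actually have positives.

The sketch below returns per-anchor losses so that both reductions can be compared. It rests on a few assumptions. The helper `supcon_per_anchor` is hypothetical. It assumes L2-normalized embeddings. It also swaps the max-subtraction and +1e-8 trick for a log-softmax computed with `torch.logsumexp`. Averaging only over anchors with positives is an alternative design choice, not what the code above does.

```python
import torch
import torch.nn.functional as F

def supcon_per_anchor(features, labels, temperature=0.07):
    """Per-anchor supervised contrastive loss (hypothetical helper)."""
    z = F.normalize(features, dim=1)
    logits = z @ z.T / temperature
    n = z.shape[0]
    not_self = ~torch.eye(n, dtype=torch.bool, device=z.device)
    positives = (labels.unsqueeze(1) == labels.unsqueeze(0)) & not_self
    # log-softmax over A(i): the anchor itself is excluded from the denominator
    logits = logits.masked_fill(~not_self, float("-inf"))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    num_pos = positives.sum(dim=1)
    # sum log-probabilities over positives only (the -inf diagonal is never selected)
    pos_log_prob = torch.where(positives, log_prob, torch.zeros_like(log_prob)).sum(dim=1)
    loss = -pos_log_prob / num_pos.clamp(min=1)  # anchors without positives get 0
    return loss, num_pos > 0

x = torch.randn(5, 192)
label = torch.LongTensor([0, 3, 9, 3, 1])
loss, has_pos = supcon_per_anchor(x, label)
print(loss.mean())           # average over all 5 anchors, like torch.mean in ContrastiveLoss
print(loss[has_pos].mean())  # average over the 2 anchors that have positives (NaN if there are none)
```

If no anchor in a batch has a positive, `loss[has_pos].mean()` is NaN, so a training loop using this reduction would need to guard against that case.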