Supervised contrastive loss is widely used in text and image classification. In this tutorial, we will show you how to implement it in PyTorch.
Supervised Contrastive Loss
We can define this loss as follows:
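This is the supervised contrastive (SupCon) loss. Writing $z_i$ for the representation of sample $i$, it is:

$$\mathcal{L} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

Here $I$ is the batch, $P(i)$ is the set of positives of anchor $i$ (the other samples with the same label), $A(i)$ is the set of all samples in the batch except $i$ itself, and $\tau$ is a temperature (the parameter r in the code below).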
The main idea of contrastive learning is to maximize the agreement between pairs of positive samples and the difference between pairs of negative samples.
Supervised Contrastive Loss in a Training Batch
We usually train a model in batches. For example, when training a text classification model, we may feed 64 documents into the model at each step.
Within a batch, some samples may belong to the same class.
For example:
| Document | Label |
| --- | --- |
| D1 | 0 |
| D2 | 1 |
| D3 | 2 |
| D4 | 1 |
| D5 | 1 |
There are 5 documents in this batch. For anchor D2, D4 and D5 are its positive samples (they all share label 1), while D1 and D3 are its negative samples.
Based on this batch, we can build a supervised contrastive loss for it.
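As a quick sketch (assuming the labels above are held in a PyTorch tensor), the positive mask for this batch can be derived directly from the labels:

import torch

labels = torch.LongTensor([0, 1, 2, 1, 1])  # labels of D1 ... D5
# positive mask: True where two different samples share a label
mask = labels.unsqueeze(1) == labels.unsqueeze(0)
mask &= ~torch.eye(len(labels), dtype=torch.bool)  # a sample is not its own positive
print(mask[1])  # positives of D2: tensor([False, False, False,  True,  True])

This is exactly the role played by mask_similar_class and mask_anchor_out in the implementation below.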
Implement Supervised Contrastive Loss in PyTorch
Here we will create a PyTorch module to implement this loss.
For example:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, r=0.07):
        super(ContrastiveLoss, self).__init__()
        self.r = r  # temperature

    # x: 64*200, y: 64*200
    def cosine_similarity(self, x, y):
        # pairwise cosine similarity, output: 64*64
        cosine = F.linear(F.normalize(x, dim=1), F.normalize(y, dim=1))
        return cosine

    def forward(self, outputs, targets):
        # pairwise cosine similarity between all samples in the batch
        cosine = self.cosine_similarity(outputs, outputs)
        dot_product_tempered = cosine / self.r
        # subtract the row-wise maximum before exp for numerical stability
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0].detach())
            + 1e-8
        )
        # mask_similar_class[i][j] is True when samples i and j share a label
        mask_similar_class = (targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets).to(outputs.device)
        # exclude each sample from its own positive set
        mask_anchor_out = (1 - torch.eye(exp_dot_tempered.shape[0])).to(outputs.device)
        mask_combined = mask_similar_class * mask_anchor_out
        # number of positive samples for each anchor
        cardinality_per_samples = torch.sum(mask_combined, dim=1)
        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / (cardinality_per_samples + 1e-8)
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)
        return supervised_contrastive_loss
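Two details in this implementation are worth noting: r is the temperature (0.07 is the default reported in the SupCon paper), and the row-wise maximum is subtracted before torch.exp purely for numerical stability; because the same constant is subtracted in both the numerator and the denominator of the softmax, it cancels out and does not change the loss value.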
We can use it as follows:
if __name__ == "__main__":
    x = torch.randn(5, 192)
    ce = ContrastiveLoss()
    label = torch.LongTensor([0, 3, 9, 3, 1])
    l = ce(x, label)
    print(l)
Run this code and we will get a loss value like the following (the exact number will vary from run to run, because x is random):
tensor(0.5545)
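Finally, here is a hypothetical sketch of how this loss might be plugged into a training step. The encoder, its dimensions, and the optimizer settings are illustrative assumptions, not part of the code above:

import torch
import torch.nn as nn

encoder = nn.Linear(768, 192)  # stand-in for a real text encoder
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)
criterion = ContrastiveLoss(r=0.07)

features = torch.randn(64, 768)        # a batch of 64 document embeddings
labels = torch.randint(0, 10, (64,))   # their class labels

loss = criterion(encoder(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()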