Implement Focal Loss for Multi Label Classification in PyTorch – PyTorch Tutorial

By | July 12, 2022

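Focal loss is one method for dealing with imbalanced datasets in deep learning: it reduces the loss contributed by well-classified (easy) examples so that training concentrates on hard, misclassified ones. In this tutorial, we will introduce how to implement focal loss for multi-label classification in PyTorch. We also implement it in TensorFlow in a separate tutorial: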

Implement Focal Loss for Multi Label Classification in TensorFlow
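Focal loss multiplies the cross-entropy of each example by a modulating factor (1 - p_t)^gamma, where p_t is the predicted probability of the true class. The small sketch below is not part of the original post; gamma = 2, alpha = 1 and the three p_t values are illustrative assumptions. It shows how this factor shrinks the loss of easy examples much more than that of hard ones:

```python
import torch

# Illustrative p_t values (assumed): predicted probability of the true class
p_t = torch.tensor([0.9, 0.5, 0.1])  # easy, medium, hard example
gamma = 2.0  # assumed focusing parameter; alpha is taken as 1

ce = -torch.log(p_t)         # ordinary cross-entropy per example
factor = (1 - p_t) ** gamma  # focal modulating factor
focal = factor * ce          # focal loss per example (alpha = 1)

for p, f, c, fl in zip(p_t.tolist(), factor.tolist(), ce.tolist(), focal.tolist()):
    print(f"p_t={p:.1f}  factor={f:.2f}  CE={c:.4f}  focal={fl:.4f}")
```

The easy example (p_t = 0.9) keeps only 1% of its cross-entropy, while the hard example (p_t = 0.1) keeps 81% of it.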

Here is a focal loss function example:

import torch
import torch.nn.functional as F
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

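    def forward(self, inputs, targets):
        '''
        :param inputs: raw logits of shape (batch_size, num_classes)
        :param targets: class indices of shape (batch_size,), dtype torch.int64
        :return: scalar focal loss
        '''
        # F.cross_entropy() averages over the batch by default (reduction='mean')
        ce_loss = F.cross_entropy(inputs, targets)
        # exp(-ce_loss) turns the cross-entropy back into a probability p_t;
        # because ce_loss is a batch mean, this is the geometric mean of p_t
        loss = self.alpha * (1 - torch.exp(-ce_loss)) ** self.gamma * ce_loss
        return loss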

Here we use F.cross_entropy() to compute the cross-entropy from the raw logits and then turn it into the focal loss.
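The loss line in forward() follows from the standard focal loss definition. Writing p_t for the predicted probability of the true class, with alpha and gamma as in the constructor:

```latex
\mathrm{CE} = -\log p_t
\quad\Longrightarrow\quad
p_t = e^{-\mathrm{CE}},
\qquad
\mathrm{FL}(p_t) = -\alpha\,(1 - p_t)^{\gamma}\log p_t
                 = \alpha\,\bigl(1 - e^{-\mathrm{CE}}\bigr)^{\gamma}\,\mathrm{CE}
```

Setting gamma = 0 removes the modulating factor and gives plain cross-entropy scaled by alpha.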

For details on this function, see: Understand F.cross_entropy(): Compute The Cross Entropy Loss – PyTorch Tutorial

Then we can compute it as follows:

inputs = torch.randn(3, 5, requires_grad=True)  # logits: batch of 3, 5 classes
print(inputs)
target = torch.randint(5, (3,), dtype=torch.int64)  # one class index per sample
print(target)

focal_loss = FocalLoss()
loss = focal_loss(inputs, target)
print(loss)

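In this example, the batch size is 3 and there are 5 classes. Running this code prints the random logits, the target class indices and the focal loss (your values will differ because the input is random):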

tensor([[-0.5062, -1.3323, -0.2291,  0.8467,  0.2213],
        [-0.2047, -0.6323, -0.4368,  0.5923, -0.0875],
        [ 0.0464, -1.8035,  0.5451, -0.2998, -0.1924]], requires_grad=True)
tensor([0, 4, 0])
tensor(1.1985, grad_fn=<MulBackward0>)
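Two caveats about the class above. First, F.cross_entropy() averages over the batch by default, so the focal factor is computed once from the batch-mean loss instead of from each sample's own p_t; focal loss is normally applied per sample and then averaged. Second, F.cross_entropy() expects exactly one class index per sample, which is multi-class rather than multi-label classification. The sketch below uses two hypothetical helpers, focal_loss_multiclass and focal_loss_multilabel (names chosen here, not from the original post), that address both points: the first passes reduction="none" to keep per-sample losses, and the second uses F.binary_cross_entropy_with_logits() on 0/1 multi-hot targets so that each label is an independent binary problem. As in the class above, alpha is assumed to be a constant scalar rather than a class-dependent weight.

```python
import torch
import torch.nn.functional as F


def focal_loss_multiclass(logits, targets, alpha=1.0, gamma=2.0):
    """Per-sample focal loss for single-label (multi-class) targets (hypothetical helper).

    logits: (batch, num_classes) raw scores; targets: (batch,) class indices.
    """
    # Keep one cross-entropy value per sample instead of the batch mean
    ce = F.cross_entropy(logits, targets, reduction="none")
    p_t = torch.exp(-ce)  # probability assigned to each sample's true class
    return (alpha * (1 - p_t) ** gamma * ce).mean()


def focal_loss_multilabel(logits, targets, alpha=1.0, gamma=2.0):
    """Focal loss for multi-label targets (hypothetical helper).

    logits: (batch, num_labels) raw scores; targets: (batch, num_labels) of 0/1.
    Each label is treated as an independent binary classification.
    """
    targets = targets.float()
    # Element-wise binary cross-entropy on the sigmoid of the logits
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the correct 0/1 value
    return (alpha * (1 - p_t) ** gamma * bce).mean()


torch.manual_seed(0)
logits = torch.randn(3, 5, requires_grad=True)

class_targets = torch.randint(5, (3,))        # one class per sample
label_targets = torch.randint(0, 2, (3, 5))   # any number of labels per sample

print(focal_loss_multiclass(logits, class_targets))
print(focal_loss_multilabel(logits, label_targets))
```

With gamma = 0 both helpers reduce to alpha times the ordinary mean cross-entropy, which is a quick sanity check.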
