Focal loss is one method for handling imbalanced datasets in deep learning. In this tutorial, we will introduce how to implement focal loss for multi-label classification in PyTorch. A TensorFlow implementation is covered in:
Implement Focal Loss for Multi Label Classification in TensorFlow
Here is a focal loss function example:
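```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # constant weighting factor
        self.gamma = gamma  # focusing parameter; gamma = 0 gives plain cross-entropy

    def forward(self, inputs, targets):
        '''
        :param inputs: (batch_size, num_classes) raw logits
        :param targets: (batch_size,) class indices
        :return: scalar focal loss
        '''
        # reduction='mean' by default: one averaged loss for the whole batch
        ce_loss = F.cross_entropy(inputs, targets)
        # exp(-CE) recovers the probability p_t assigned to the true class
        loss = self.alpha * (1 - torch.exp(-ce_loss)) ** self.gamma * ce_loss
        return loss
```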
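Here we use F.cross_entropy() to compute the cross-entropy loss and then turn it into a focal loss. Two details are worth knowing. First, F.cross_entropy() averages over the batch by default, so the modulating factor (1 - p_t)^γ is applied to the mean loss rather than to each sample. Second, F.cross_entropy() expects exactly one class index per sample, so strictly speaking this example is single-label multi-class classification.
To learn more about F.cross_entropy(), see: Understand F.cross_entropy(): Compute The Cross Entropy Loss – PyTorch Tutorial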
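The forward() method relies on the fact that cross-entropy is the negative log-probability of the true class, so that probability can be recovered from the loss itself. Writing $p_t$ for the softmax probability of the target class (notation assumed here, following the common focal-loss definition), for a single sample:

$$
\begin{aligned}
\mathrm{CE} &= -\log p_t \;\Longrightarrow\; p_t = e^{-\mathrm{CE}} \\
\mathrm{FL}(p_t) &= -\alpha\,(1 - p_t)^{\gamma}\log p_t = \alpha\,\bigl(1 - e^{-\mathrm{CE}}\bigr)^{\gamma}\,\mathrm{CE}
\end{aligned}
$$

With $\gamma = 0$ and $\alpha = 1$ this is plain cross-entropy; a larger $\gamma$ shrinks the loss of samples the model already classifies confidently ($p_t$ close to 1).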
Then we can compute it as follows:
```python
input = torch.randn(3, 5, requires_grad=True)
print(input)
target = torch.randint(5, (3,), dtype=torch.int64)
print(target)

focalloss = FocalLoss()
loss = focalloss(input, target)
print(loss)
```
In this example, the batch size is 3 and there are 5 classes. Running this code prints the random inputs, the targets and the focal loss (the values change on every run):
```
tensor([[-0.5062, -1.3323, -0.2291,  0.8467,  0.2213],
        [-0.2047, -0.6323, -0.4368,  0.5923, -0.0875],
        [ 0.0464, -1.8035,  0.5451, -0.2998, -0.1924]], requires_grad=True)
tensor([0, 4, 0])
tensor(1.1985, grad_fn=<MulBackward0>)
```
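Because F.cross_entropy() uses reduction='mean' by default, the value above is the focal factor applied to the average cross-entropy of the batch. The usual focal loss instead applies (1 - p_t)^γ to each sample before averaging. Below is a minimal sketch of that per-sample version using reduction='none', plus a multi-label variant built on F.binary_cross_entropy_with_logits(), where every label is treated as an independent binary problem. The function names are hypothetical, and for simplicity both keep a single constant alpha like the class above rather than a per-class or per-label weight.

```python
import torch
import torch.nn.functional as F


def focal_loss_multiclass(logits, targets, alpha=1.0, gamma=2.0):
    """Per-sample focal loss for single-label multi-class targets (hypothetical helper).

    logits: (batch, num_classes) raw scores; targets: (batch,) class indices.
    """
    # Cross-entropy for each sample separately, shape (batch,)
    ce = F.cross_entropy(logits, targets, reduction="none")
    # p_t = probability of the true class, since ce = -log(p_t)
    p_t = torch.exp(-ce)
    # Down-weight well-classified samples first, then average over the batch
    return (alpha * (1 - p_t) ** gamma * ce).mean()


def focal_loss_multilabel(logits, targets, alpha=1.0, gamma=2.0):
    """Focal loss for multi-label targets (hypothetical helper).

    logits, targets: (batch, num_labels); targets hold 0/1 floats.
    A constant alpha is used here as a simplifying assumption.
    """
    # Binary cross-entropy for every (sample, label) pair, shape (batch, num_labels)
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # Probability the model assigns to the correct 0/1 value of each label
    p_t = torch.exp(-bce)
    return (alpha * (1 - p_t) ** gamma * bce).mean()


torch.manual_seed(0)

# Multi-class: 3 samples, 5 classes, one class index per sample
logits = torch.randn(3, 5, requires_grad=True)
targets = torch.randint(5, (3,), dtype=torch.int64)
loss = focal_loss_multiclass(logits, targets)
loss.backward()
print(loss)

# Multi-label: 3 samples, 5 labels, any number of labels may be active
ml_logits = torch.randn(3, 5)
ml_targets = torch.randint(0, 2, (3, 5)).float()
print(focal_loss_multilabel(ml_logits, ml_targets))
```

With gamma=0 and alpha=1 both functions reduce to the corresponding mean cross-entropy, which is a quick way to sanity-check an implementation.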