In PyTorch, we can use torch.nn.functional.cross_entropy() to compute the cross entropy loss between input logits and targets. In this tutorial, we will introduce how to use it.
Cross Entropy Loss
It is defined as:
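For a single sample, the standard form can be written as follows. This is a sketch that assumes no class weights and no label smoothing; the symbols are introduced here for illustration: x is the vector of raw scores (logits) over C classes and y is the index of the target class.

```latex
\ell(x, y) = -\log \frac{\exp(x_{y})}{\sum_{j=1}^{C} \exp(x_{j})}
           = -x_{y} + \log \sum_{j=1}^{C} \exp(x_{j})
```

With the default reduction, the loss of a batch is the average of this value over its N samples.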
This loss is often used in classification problems. The gradient of this loss is explained here:
Understand the Gradient of Cross Entropy Loss Function – Machine Learning Tutorial
F.cross_entropy()
It is defined as:
- torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
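input: the input tensor, usually with shape (N, C). Note that it must contain raw scores (logits), not the output of softmax(), because F.cross_entropy() applies log_softmax() internally.
target: it usually contains the class index of each sample, not a one-hot vector. Its shape is (N,) and each value is in [0, C-1]. In recent PyTorch versions (1.10 and later), it can also be a float tensor of class probabilities with shape (N, C), for example a one-hot encoding.
To learn how to create a one-hot encoding, you can read: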
Understand torch.nn.functional.one_hot() with Examples – PyTorch Tutorial
Here N is the batch size and C is the number of classes.
reduction: it can be 'none', 'mean' or 'sum'. It determines how the per-sample losses are combined into the returned value. 'mean' is the default.
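The point about passing raw logits can be checked with a small sketch. The tensor values and variable names below are made up for illustration. It compares F.cross_entropy() with the equivalent log_softmax() + nll_loss() pipeline, and shows that applying softmax() first gives a different (incorrect) value.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for N = 2 samples and C = 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
target = torch.tensor([0, 2])  # class indices, shape (N,)

# Correct usage: pass the raw logits directly.
loss_direct = F.cross_entropy(logits, target)

# Equivalent two-step computation: log_softmax followed by nll_loss.
loss_two_step = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Common mistake: applying softmax before cross_entropy,
# which normalizes the scores twice.
loss_wrong = F.cross_entropy(F.softmax(logits, dim=1), target)

print(loss_direct)    # same value as loss_two_step
print(loss_two_step)
print(loss_wrong)     # a different, larger value for this input
```

If a model already ends with a softmax layer, one option is to remove that layer and feed its input to F.cross_entropy() instead.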
How to use F.cross_entropy()?
First, we should import some libraries.
- import torch
- import torch.nn.functional as F
- import numpy as np
Then we create an input and a target.
- input = torch.randn(3, 5, requires_grad=True)
- print(input)
- target = torch.randint(5, (3,), dtype=torch.int64)
- print(target)
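Here the batch size is 3 and the number of classes is 5.
Running this code, we will see output like the following (the values are random, so yours will differ):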
- tensor([[ 1.0491, 0.3516, -1.2480, 0.4829, -1.2766],
- [ 1.0018, -0.7298, 0.8515, -0.5951, 1.3111],
- [ 0.3726, -0.6266, -1.1043, 0.6200, 0.3317]], requires_grad=True)
- tensor([0, 4, 3])
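As a side note, a target of class indices such as the one above corresponds to a one-hot matrix of class probabilities. The following sketch assumes PyTorch 1.10 or later, where F.cross_entropy() also accepts a float target of shape (N, C). The names target_idx and target_prob are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the example is repeatable
logits = torch.randn(3, 5)                                 # N = 3, C = 5
target_idx = torch.randint(5, (3,), dtype=torch.int64)     # class indices, shape (3,)

# One-hot version of the same labels, as float class probabilities, shape (3, 5).
target_prob = F.one_hot(target_idx, num_classes=5).float()

loss_idx = F.cross_entropy(logits, target_idx)
# Assumption: probability targets require PyTorch >= 1.10.
loss_prob = F.cross_entropy(logits, target_prob)

print(target_idx.shape, target_prob.shape)
print(loss_idx, loss_prob)  # the two losses should match
```

Class indices are usually the simpler and more memory-efficient choice; probability targets are mainly useful for soft labels.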
Then we can start to compute cross entropy loss.
- loss = F.cross_entropy(input, target)
- print(loss)
- loss = F.cross_entropy(input, target, reduction='sum')
- print(loss)
- loss = F.cross_entropy(input, target, reduction='none')
- print(loss)
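Running this code, we will see output like the following (in newer PyTorch versions the grad_fn name may be printed slightly differently, for example NllLossBackward0):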
- tensor(0.9622, grad_fn=<NllLossBackward>)
- tensor(2.8866, grad_fn=<NllLossBackward>)
- tensor([0.8170, 0.9723, 1.0973], grad_fn=<NllLossBackward>)
We should notice:
When reduction='none', the cross entropy loss of each sample is returned as a tensor of shape (N,). With 'mean' or 'sum', a single scalar is returned.
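The relationship between the three reduction modes can be checked with a short sketch. It reuses the input and target values printed earlier (rounded to 4 decimals, so results only match the printed losses approximately) and also recomputes the per-sample loss by hand with logsumexp. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

# Values copied from the printed input and target above (rounded to 4 decimals).
logits = torch.tensor([[1.0491, 0.3516, -1.2480, 0.4829, -1.2766],
                       [1.0018, -0.7298, 0.8515, -0.5951, 1.3111],
                       [0.3726, -0.6266, -1.1043, 0.6200, 0.3317]])
target = torch.tensor([0, 4, 3])

per_sample = F.cross_entropy(logits, target, reduction='none')  # shape (3,)
total = F.cross_entropy(logits, target, reduction='sum')        # scalar
mean = F.cross_entropy(logits, target, reduction='mean')        # scalar

# Manual computation: loss_i = -x[i, y_i] + log(sum_j exp(x[i, j]))
log_probs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
manual = -log_probs[torch.arange(logits.shape[0]), target]

print(per_sample)  # approximately [0.8170, 0.9723, 1.0973]
print(torch.allclose(total, per_sample.sum()))   # 'sum' adds the per-sample losses
print(torch.allclose(mean, per_sample.mean()))   # 'mean' averages them
print(torch.allclose(manual, per_sample))        # matches the manual formula
```

Note that 'mean' equals a plain average only when no class weights and no ignored indices are involved, as in this example.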