In PyTorch, we can use torch.nn.functional.one_hot() to create one-hot encodings, which are very useful in classification problems. In this tutorial, we will show how to use it.
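As a minimal sketch of why one-hot targets are handy in classification, the example below writes cross-entropy explicitly as -sum(target * log p) using one-hot targets. It assumes a hypothetical batch of 4 samples and 5 classes with random logits. The result should match F.cross_entropy, which takes class indices directly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 5 classes (random logits for illustration only)
logits = torch.randn(4, 5)
labels = torch.tensor([1, 0, 4, 2])

# One-hot targets, cast to float so they can be multiplied with log-probabilities
targets = F.one_hot(labels, num_classes=5).float()

# Cross-entropy with one-hot targets: -sum(target * log p), averaged over the batch
manual_loss = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Built-in cross-entropy, which takes the class indices directly
builtin_loss = F.cross_entropy(logits, labels)

print(manual_loss.item(), builtin_loss.item())
```

Multiplying by the one-hot row keeps only the log-probability of the true class, which is exactly what cross-entropy with index targets computes.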
It is defined as:
torch.nn.functional.one_hot(tensor, num_classes=-1)
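Here the tensor argument must be a LongTensor (dtype torch.int64) of class indices: each value gives the position of the 1 in the corresponding one-hot vector. num_classes is the length of each one-hot vector; with the default of -1, it is inferred as the largest value in tensor plus one. The result is a LongTensor with one extra last dimension of size num_classes.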
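A small sketch of how num_classes behaves, using hypothetical indices [0, 2, 1]: the default infers the number of classes from the data, an explicit value pads with zero columns, and indices stored as floats need a .long() cast first.

```python
import torch
import torch.nn.functional as F

# Class indices must be an integer (int64) tensor
idx = torch.tensor([0, 2, 1])
print(idx.dtype)  # torch.int64

# Default num_classes=-1: the number of classes is inferred as max(idx) + 1 = 3
inferred = F.one_hot(idx)
print(inferred.shape)  # torch.Size([3, 3])

# An explicit num_classes larger than needed adds extra all-zero columns
explicit = F.one_hot(idx, num_classes=5)
print(explicit.shape)  # torch.Size([3, 5])

# Indices held in a float tensor must be cast back to int64 before calling one_hot
float_idx = torch.tensor([0.0, 2.0, 1.0])
from_float = F.one_hot(float_idx.long(), num_classes=3)
```

Passing num_classes explicitly is usually safer for batches, since a batch that happens not to contain the largest class would otherwise produce narrower vectors.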
For example:
import torch
import numpy as np

label = torch.LongTensor(np.array([2, 9, 1, 3, 0, 3]))
print(label)

class_num = 10  # 10 classes: indices 0-9
one_hot = torch.nn.functional.one_hot(label, class_num)
print(one_hot)
Here label is a LongTensor, which determines the position of the 1 in each row of one_hot.
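To check that each label value really marks the position of the 1, a quick sketch can index each row at its label and recover the labels back with argmax.

```python
import torch
import torch.nn.functional as F

label = torch.tensor([2, 9, 1, 3, 0, 3])
one_hot = F.one_hot(label, num_classes=10)

# Row i has its single 1 at column label[i]
rows = torch.arange(len(label))
print(one_hot[rows, label])  # tensor([1, 1, 1, 1, 1, 1])

# argmax along the class dimension inverts the encoding
recovered = one_hot.argmax(dim=1)
print(recovered)  # tensor([2, 9, 1, 3, 0, 3])
```

The argmax round trip is also how predicted class indices are usually read off from model outputs of the same (N, num_classes) layout.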
Running this code prints label and then one_hot, one row per label. Note the shapes: label has shape (6,), so one_hot has shape (6, 10), because class_num = 10.
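One way to see why a label of shape (6,) gives a (6, 10) result: row i of a 10 x 10 identity matrix is the one-hot vector for class i, so indexing that matrix with label selects one row per element. This is offered as an equivalent formulation, not as how PyTorch implements one_hot internally.

```python
import torch
import torch.nn.functional as F

label = torch.tensor([2, 9, 1, 3, 0, 3])
class_num = 10

# Select one identity-matrix row per label: shape (6,) -> (6, 10)
via_eye = torch.eye(class_num, dtype=torch.long)[label]
via_one_hot = F.one_hot(label, class_num)

print(via_eye.shape)  # torch.Size([6, 10])
print(torch.equal(via_eye, via_one_hot))  # True
```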
What if label has shape (2, 3)?
label = torch.LongTensor(np.array([[2, 9, 1], [3, 0, 3]]))
print(label)
print(label.shape)

class_num = 10  # 10 classes: indices 0-9
one_hot = torch.nn.functional.one_hot(label, class_num)
print(one_hot)
print(one_hot.shape)
The shapes and the one-hot result are as follows:
torch.Size([2, 3])
tensor([[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]]])
torch.Size([2, 3, 10])
The one_hot tensor has shape (2, 3, 10): the input shape (2, 3) with a new last dimension of size class_num appended.
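The same rule holds for any number of leading dimensions. The sketch below uses hypothetical random labels of shape (2, 3, 4) to show that the output shape is always label.shape plus (num_classes,), and that each position holds exactly one 1 along the last dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
class_num = 10

# Hypothetical labels with an arbitrary leading shape (2, 3, 4)
label = torch.randint(0, class_num, (2, 3, 4))
one_hot = F.one_hot(label, class_num)

# Output shape = input shape with num_classes appended as the last dimension
print(one_hot.shape)  # torch.Size([2, 3, 4, 10])
print(one_hot.shape == label.shape + (class_num,))  # True

# Summing over the class dimension gives 1 at every position
print(bool((one_hot.sum(dim=-1) == 1).all()))  # True
```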