Understand torch.nn.functional.one_hot() with Examples – PyTorch Tutorial

By | July 12, 2022

In PyTorch, we can use torch.nn.functional.one_hot() to create one-hot encodings, which are very useful in classification problems. In this tutorial, we will introduce how to use it.


torch.nn.functional.one_hot()

It is defined as:

torch.nn.functional.one_hot(tensor, num_classes=-1)

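Here the tensor parameter must be a LongTensor (dtype torch.int64). Each value in it is the index of the 1 in the corresponding one-hot vector. num_classes is the length of each one-hot vector; with the default value -1, it is inferred as one greater than the largest value in tensor. The function returns a LongTensor with one more dimension than tensor, and that last dimension has size num_classes.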
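A small sketch of these rules: the default num_classes=-1 infers the class count from the largest value, an explicit num_classes adds columns for classes that never appear, and a floating-point tensor has to be cast to Long first. The example labels here are made up for illustration.

```python
import torch
import torch.nn.functional as F

# torch.tensor() on Python ints gives an int64 (Long) tensor by default
label = torch.tensor([2, 0, 1])

# num_classes=-1 (the default): inferred as label.max() + 1, i.e. 3 here
inferred = F.one_hot(label)
print(inferred.shape)  # torch.Size([3, 3])

# An explicit num_classes adds zero columns for classes that never appear
explicit = F.one_hot(label, num_classes=5)
print(explicit.shape)  # torch.Size([3, 5])

# A floating-point tensor is not accepted as an index tensor
float_label = torch.tensor([2.0, 0.0, 1.0])
rejected = False
try:
    F.one_hot(float_label)
except RuntimeError:
    rejected = True

# Casting to Long first works
fixed = F.one_hot(float_label.long())
```

Because the inferred class count depends only on the values present in that particular tensor, passing num_classes explicitly is usually safer when encoding batches, since a batch may not contain the largest class.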

For example:

import torch
import numpy as np

label = torch.LongTensor(np.array([2,9,1,3,0,3]))
print(label)
class_num = 10  # class indices 0-9

one_hot = torch.nn.functional.one_hot(label, class_num)
print(one_hot)

Here label is a LongTensor, and each of its values determines the position of the 1 in the corresponding row of one_hot.

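Running this code, we will see:

tensor([2, 9, 1, 3, 0, 3])
tensor([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])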

We can also see that label has shape (6,), so one_hot has shape (6, 10), where 10 is class_num.
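To make the "position of the 1" idea concrete, here is a sketch showing that, for these inputs, one_hot() gives the same result as picking rows of an identity matrix. The helper manual_one_hot is hypothetical and only for illustration; it is not part of PyTorch.

```python
import torch
import torch.nn.functional as F

def manual_one_hot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Hypothetical helper: row i of the identity matrix has its 1 at column i,
    # so indexing with label picks one row per label value.
    eye = torch.eye(num_classes, dtype=torch.long)
    return eye[label]

label = torch.tensor([2, 9, 1, 3, 0, 3])
expected = F.one_hot(label, 10)
manual = manual_one_hot(label, 10)
print(manual.shape)                    # torch.Size([6, 10])
print(torch.equal(expected, manual))   # True
```

The identity-matrix view also explains the output shape: each label value selects one row of length num_classes, so a label of shape (6,) yields a (6, 10) result.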

What if label has shape (2, 3)?

label = torch.LongTensor(np.array([[2,9,1],[3,0,3]]))
print(label)
print(label.shape)
class_num = 10  # class indices 0-9

one_hot = torch.nn.functional.one_hot(label, class_num)
print(one_hot)
print(one_hot.shape)

Running it, we will get:

tensor([[2, 9, 1],
        [3, 0, 3]])
torch.Size([2, 3])
tensor([[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]]])
torch.Size([2, 3, 10])

The one-hot tensor has shape (2, 3, 10): one_hot() keeps the shape of label and appends a last dimension of size class_num.
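A short sketch of the general shape rule, checked on a few random label shapes chosen for illustration, plus two follow-up steps that are often handy: recovering the labels with argmax over the last axis, and casting the int64 result to float when later computations need floating-point values.

```python
import torch
import torch.nn.functional as F

num_classes = 10
shapes_ok = []
for shape in [(6,), (2, 3), (4, 2, 5)]:
    # Random class indices in [0, num_classes); randint returns int64 by default
    label = torch.randint(0, num_classes, shape)
    out = F.one_hot(label, num_classes)
    # Output shape = input shape with num_classes appended as the last axis
    shapes_ok.append(tuple(out.shape) == tuple(label.shape) + (num_classes,))
print(shapes_ok)  # [True, True, True]

label = torch.tensor([[2, 9, 1], [3, 0, 3]])
one_hot = F.one_hot(label, num_classes)  # shape (2, 3, 10)

# Each one-hot vector has a single 1, so argmax over the last axis recovers label
recovered = one_hot.argmax(dim=-1)
print(torch.equal(recovered, label))  # True

# one_hot() returns int64; cast it if later computations need floating point
one_hot_float = one_hot.float()
print(one_hot_float.dtype)  # torch.float32
```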
