Understand torch.unsqueeze() with Examples – PyTorch Tutorial

By | April 26, 2022

In this tutorial, we will use some examples to show you how to use PyTorch torch.unsqueeze() correctly.

torch.unsqueeze()

It is defined as:

torch.unsqueeze(input, dim)

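It returns a new tensor with a dimension of size one inserted at position dim. The returned tensor shares the same underlying data with input. For an input with n dimensions, dim must be in the range [-n-1, n]; a negative dim is treated as dim + n + 1.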
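Negative values of dim count from the end. The sketch below uses a 2-D input like the one in the next example. It checks how negative dim values map onto positive ones and what happens when dim is out of range. The variable names last, first and out_of_range are illustrative only.

```python
import torch

x = torch.randn(3, 4)  # x.dim() == 2, so valid dim values are -3 to 2

# A negative dim is converted to dim + x.dim() + 1
last = torch.unsqueeze(x, dim=-1)   # -1 + 2 + 1 = 2
first = torch.unsqueeze(x, dim=-3)  # -3 + 2 + 1 = 0
print(last.shape)   # torch.Size([3, 4, 1])
print(first.shape)  # torch.Size([1, 3, 4])

# A dim outside [-3, 2] raises an IndexError
try:
    torch.unsqueeze(x, dim=3)
    out_of_range = False
except IndexError:
    out_of_range = True
print(out_of_range)  # True
```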

How to use it?

The following example inserts a new dimension into a 3 x 4 tensor at two different positions:

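import torch

x = torch.randn(3, 4)  # a 2-D tensor with shape [3, 4]
print(x.shape)

# insert a size-1 dimension at position 1
x1 = torch.unsqueeze(x, dim=1)
print(x1.shape)

# insert a size-1 dimension at position 0
x2 = torch.unsqueeze(x, dim=0)
print(x2.shape)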

Running this code prints:

torch.Size([3, 4])
torch.Size([3, 1, 4])
torch.Size([1, 3, 4])

x has shape [3, 4].

With dim = 1, a new dimension is inserted at position 1, so x1 has shape [3, 1, 4].

With dim = 0, a new dimension is inserted at position 0, so x2 has shape [1, 3, 4].
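Two equivalent spellings are also common in PyTorch code: the tensor method x.unsqueeze(dim) and indexing with None. The sketch below compares both on a 3 x 4 tensor. It then checks that the unsqueezed result is a view that shares memory with x. The name x1_idx is illustrative only.

```python
import torch

x = torch.randn(3, 4)

# Method form: x.unsqueeze(1) is the same as torch.unsqueeze(x, 1)
x1 = x.unsqueeze(1)
# Indexing with None inserts a size-1 dimension at the same position
x1_idx = x[:, None, :]
print(x1.shape, x1_idx.shape)   # torch.Size([3, 1, 4]) torch.Size([3, 1, 4])
print(torch.equal(x1, x1_idx))  # True

# The result shares data with x, so writing to x is visible through x1
x[0, 0] = 100.0
print(x1[0, 0, 0].item())       # 100.0
```

Because no data is copied, unsqueeze() is cheap. It also means that any in-place change to either tensor affects the other.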

This function is equivalent to TensorFlow's tf.expand_dims(). You can learn more in this tutorial:

Understand tf.expand_dims() with Examples – TensorFlow Tutorial
