In this tutorial, we will use some examples to show you how to use PyTorch torch.unsqueeze() correctly.
torch.unsqueeze()
It is defined as:
torch.unsqueeze(input, dim)
It returns a new tensor with a dimension of size one inserted at position dim.
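The definition above leaves out two details from the PyTorch documentation. First, dim may be negative: valid values run from -input.dim() - 1 to input.dim(), and a negative dim counts from the end. Second, the returned tensor shares its underlying data with input. The following minimal sketch checks both. It assumes a recent PyTorch release, where an out-of-range dim raises IndexError.

```python
import torch

x = torch.randn(3, 4)  # x.dim() == 2, so valid dim values are -3 .. 2

# A negative dim counts from the end (dim + x.dim() + 1)
last = torch.unsqueeze(x, dim=-1)   # same as dim=2
first = torch.unsqueeze(x, dim=-3)  # same as dim=0
print(last.shape)   # torch.Size([3, 4, 1])
print(first.shape)  # torch.Size([1, 3, 4])

# The result shares storage with x, so writing through y changes x
y = torch.unsqueeze(x, dim=0)
y[0, 0, 0] = 100.0
print(x[0, 0].item())  # 100.0

# A dim outside [-3, 2] is rejected
raised = False
try:
    torch.unsqueeze(x, dim=3)
except IndexError as e:
    raised = True
    print("IndexError:", e)
```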
How to use it?
Let us look at an example:
import torch

x = torch.randn(3, 4)
print(x.shape)

x1 = torch.unsqueeze(x, dim=1)
print(x1.shape)

x2 = torch.unsqueeze(x, dim=0)
print(x2.shape)
Running this code prints:
torch.Size([3, 4])
torch.Size([3, 1, 4])
torch.Size([1, 3, 4])
x has shape [3, 4].
When dim = 1, a new dimension is inserted at position 1, so x1 has shape [3, 1, 4].
When dim = 0, a new dimension is inserted at position 0, so x2 has shape [1, 3, 4].
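The same results can be written in a few other ways. The minimal sketch below assumes current PyTorch behavior for Tensor.unsqueeze() (the method form of torch.unsqueeze()) and for indexing with None, which also inserts a size-1 dimension. Adding a batch dimension with dim = 0 is only one illustrative use, not something taken from the original example.

```python
import torch

x = torch.randn(3, 4)

# Method form: equivalent to torch.unsqueeze(x, dim=1)
x1 = x.unsqueeze(1)
print(x1.shape)  # torch.Size([3, 1, 4])

# Indexing with None inserts a size-1 dimension at the same position
x1_idx = x[:, None, :]
print(torch.equal(x1, x1_idx))  # True

# Illustrative use: give a single sample a leading batch dimension
batch = x.unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 4])

# dim may equal x.dim(), which appends a trailing dimension
x3 = x.unsqueeze(2)
print(x3.shape)  # torch.Size([3, 4, 1])
```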
This function is equivalent to TensorFlow's tf.expand_dims(). You can learn more in this tutorial:
Understand tf.expand_dims() with Examples – TensorFlow Tutorial