In pytorch, torch.nn.functional.pad() can allow us to pad a tensor easily. In this tutorial, we will introduce you how to use it with some examples.
Syntax
torch.nn.functional.pad() is defined as:
torch.nn.functional.pad(input, pad, mode='constant', value=None)
Here
input: tensor will be padded.
pad: it is a tuple, which contains m-elements. It determines how to pad a tensor.
mode: ‘constant’, ‘reflect’, ‘replicate’ or ‘circular’. Default: ‘constant’
value: fill value for ‘constant’ padding. Default: 0. We should notice value only work when mode = “constant”
How to pad a tensor based on pad parameter?
We should know:
- pad = (A,B). It contains two elements, the last dim will be padded.
- pad = (A,B, C, D). It contains four elements, the last two dim will be padded.
- pad = (A, B, C, D, E, F). It contains six elements, the last three dim will be padded.
For example:
import torch import torch.nn.functional as F x = torch.zeros(3, 3, 4, 2) p1d = (1, 1) # pad last dim by 1 on each side out = F.pad(x, p1d, "constant", 3) # effectively three padding print(x.shape) print(out.shape) print("x=",x) print("out=",out)
Here p1d = (1, 1), it contains two elements, x will be padded at the last dim.
The shape of x is (3, 3, 4, 2), the last dim is 2. p1d = (1, 1), it means the shape of out will be (3, 3, 4, 2+1+1) = (3, 3, 4, 4)
Run this code we will see:
torch.Size([3, 3, 4, 2]) torch.Size([3, 3, 4, 4]) out= tensor([[[[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]]], [[[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]]], [[[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]], [[3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.], [3., 0., 0., 3.]]]])
We can find the left and right are padded with one value 3.
If p1d = (3, 1)
The shape of out will be (3, 3, 4, 2+3+1) = (3, 3, 4, 6). The left will be padded with [3,3,3]. The right will be padded with one [3]
For example:
How about p1d = (3, 1, 2, 3)?
p1d contains 4 elements, it means the last 2 dim will be padded.
- The last dim will be padded with (3, 1)
- The second last dim will be padded with (2, 3)
The shape of x is (3, 3, 4, 2), the shape of out will be (3, 3, 4+2+3, 2+3+1) = (3, 3, 9, 6)
For example:
torch.Size([3, 3, 4, 2]) torch.Size([3, 3, 9, 6]) out= tensor([[[[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], .... [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]]]]) Process finished with exit code 0
As to the second last dim, it will be padded with (2, 3), it means the top will be padded 2 dim, the bottom will be padded with 3 dim.
How about p1d = (3, 1, 2, 3, 1, 2)?
p1d contains 6 elements, it means the last three dim will be padded.
- The last dim will be padded with (3, 1)
- The second last dim will be padded with (2, 3)
- The third last dim will be padded with (1, 2)
The shape of x is (3, 3, 4, 2), the shape of out will be (3, 3+1+2, 4+2+3, 2+3+1) = (3, 6, 9, 6)
For example:
x = torch.zeros(3, 3, 4, 2) p1d = (3, 1, 2, 3) # pad last dim by 1 on each side out = F.pad(x, p1d, "constant", 3) # effectively three padding print(x.shape) print(out.shape) print("out=",out)
Run this code, we will see:
torch.Size([3, 3, 4, 2]) torch.Size([3, 6, 9, 6]) out= tensor([[[[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], ...... [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 0., 0., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]], [[3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.], [3., 3., 3., 3., 3., 3.]]]])