Tutorial Example

Understand torch.nn.functional.pad() with Examples – PyTorch Tutorial

In pytorch, torch.nn.functional.pad() can allow us to pad a tensor easily. In this tutorial, we will introduce you how to use it with some examples.

Syntax

torch.nn.functional.pad() is defined as:

torch.nn.functional.pad(input, pad, mode='constant', value=None)

Here

input: tensor will be padded.

pad: it is a tuple, which contains m-elements. It determines how to pad a tensor.

mode: ‘constant’, ‘reflect’, ‘replicate’ or ‘circular’. Default: ‘constant’

value: fill value for ‘constant’ padding. Default: 0. We should notice value only work when mode = “constant”

How to pad a tensor based on pad parameter?

We should know:

For example:

import torch
import torch.nn.functional as F

x = torch.zeros(3, 3, 4, 2)
p1d = (1, 1) # pad last dim by 1 on each side
out = F.pad(x, p1d, "constant", 3)  # effectively three padding


print(x.shape)
print(out.shape)
print("x=",x)
print("out=",out)

Here p1d = (1, 1), it contains two elements, x will be padded at the last dim.

The shape of x is (3, 3, 4, 2), the last dim is 2. p1d = (1, 1), it means the shape of out will be (3, 3, 4, 2+1+1) = (3, 3, 4, 4)

Run this code we will see:

torch.Size([3, 3, 4, 2])
torch.Size([3, 3, 4, 4])
out= tensor([[[[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]]],


        [[[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]]],


        [[[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]],

         [[3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.],
          [3., 0., 0., 3.]]]])

We can find the left and right are padded with one value 3.

If p1d = (3, 1)

The shape of out will be (3, 3, 4, 2+3+1) = (3, 3, 4, 6). The left will be padded with [3,3,3]. The right will be padded with one [3]

For example:

How about p1d = (3, 1, 2, 3)?

p1d contains 4 elements, it means the last 2 dim will be padded.

The shape of x is (3, 3, 4, 2), the shape of out will be (3, 3, 4+2+3, 2+3+1) = (3, 3, 9, 6)

For example:

torch.Size([3, 3, 4, 2])
torch.Size([3, 3, 9, 6])
out= tensor([[[[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],
....
         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]]]])

Process finished with exit code 0

As to the second last dim, it will be padded with (2, 3), it means the top will be padded 2 dim, the bottom will be padded with 3 dim.

How about p1d = (3, 1, 2, 3, 1, 2)?

p1d contains 6 elements, it means the last three dim will be padded.

The shape of x is (3, 3, 4, 2), the shape of out will be (3, 3+1+2, 4+2+3, 2+3+1) = (3, 6, 9, 6)

For example:

x = torch.zeros(3, 3, 4, 2)
p1d = (3, 1, 2, 3) # pad last dim by 1 on each side
out = F.pad(x, p1d, "constant", 3)  # effectively three padding


print(x.shape)
print(out.shape)
print("out=",out)

Run this code, we will see:

torch.Size([3, 3, 4, 2])
torch.Size([3, 6, 9, 6])
out= tensor([[[[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],

         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],

         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],

        ......

         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 0., 0., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],

         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]],

         [[3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.],
          [3., 3., 3., 3., 3., 3.]]]])