Understand PyTorch Tensor.masked_fill() with Examples – PyTorch Tutorial

May 22, 2023

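PyTorch Tensor.masked_fill(mask, value) is the out-of-place version of Tensor.masked_fill_(mask, value). It returns a new tensor and leaves the original unchanged, while masked_fill_() modifies the tensor in place. In this tutorial, we will use some examples to show you how to use it.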
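A minimal sketch of the out-of-place versus in-place difference. The small tensors x, mask and z are illustrative values chosen for this demo, not from the original examples.

```python
import torch

# A small integer tensor and a boolean mask of the same shape
x = torch.tensor([1, 2, 3, 4])
mask = torch.tensor([True, False, True, False])

# Out-of-place: returns a new tensor, x itself is left unchanged
y = x.masked_fill(mask, 0)
print(x.tolist())  # [1, 2, 3, 4]
print(y.tolist())  # [0, 2, 0, 4]

# In-place: the trailing-underscore version modifies the tensor it is called on
z = x.clone()
z.masked_fill_(mask, 0)
print(z.tolist())  # [0, 2, 0, 4]
```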

Syntax

Tensor.masked_fill() is defined as:

Tensor.masked_fill(mask, value)

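It returns a copy of self in which the elements at positions where mask is True are replaced with value. Note: mask should be a boolean tensor (dtype torch.bool), and its shape must be broadcastable with the shape of self. value can be a Python number or a 0-dimensional tensor.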
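The sketch below illustrates both requirements on a small [2, 3] tensor chosen only for illustration. Comparison operators already return boolean masks, an integer 0/1 tensor can be converted with .bool(), and a mask whose shape is not broadcastable raises a RuntimeError.

```python
import torch

a = torch.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]], shape [2, 3]

# Comparison operators return torch.bool tensors, which are valid masks
mask = a > 3
print(mask.dtype)                   # torch.bool
filled = a.masked_fill(mask, -1)
print(filled.tolist())              # [[0, 1, 2], [3, -1, -1]]

# An existing 0/1 integer tensor can be turned into a boolean mask with .bool()
int_mask = torch.tensor([[1, 0, 0], [0, 0, 1]])
filled_from_int = a.masked_fill(int_mask.bool(), -1)
print(filled_from_int.tolist())     # [[-1, 1, 2], [3, 4, -1]]

# A mask of shape [2] cannot broadcast with [2, 3] (trailing sizes 2 vs 3)
raised = False
try:
    a.masked_fill(torch.tensor([True, False]), -1)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```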

How to use Tensor.masked_fill()?

Here are some examples.

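import torch

# a has shape [2, 3, 4]
a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
print(a)
print(a.size())
print("#############################################")
# mask has shape [2, 3, 1]; use dtype=torch.bool because uint8 masks
# (torch.ByteTensor) are deprecated and rejected by newer PyTorch versions
mask = torch.tensor([[[1], [1], [0]], [[0], [1], [1]]], dtype=torch.bool)
print(mask.size())
# Fill every position where mask is True with 100
b = a.masked_fill(mask, value=torch.tensor(100))
print(b)
print(b.size())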

In this example code, tensor a is:

tensor([[[5, 5, 5, 5],
         [6, 6, 6, 6],
         [7, 7, 7, 7]],

        [[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]]])
torch.Size([2, 3, 4])

So the shape of tensor a is [2, 3, 4].

The size of mask is:

torch.Size([2, 3, 1])

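That is, mask has shape [2, 3, 1]. Its last dimension has size 1, so each mask value is broadcast across all 4 elements of the corresponding row of a.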
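To see what broadcasting does here, a minimal sketch can expand the [2, 3, 1] mask to the full [2, 3, 4] shape with expand_as() and check that filling with either mask gives the same result. The names full_mask, b1 and b2 are only for this illustration.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
mask = torch.tensor([[[1], [1], [0]], [[0], [1], [1]]], dtype=torch.bool)

# expand_as views the [2, 3, 1] mask as [2, 3, 4] without copying data,
# repeating each row's single flag across the 4 columns
full_mask = mask.expand_as(a)
print(full_mask.shape)      # torch.Size([2, 3, 4])

# The broadcast mask and the explicitly expanded mask give the same result
b1 = a.masked_fill(mask, 100)
b2 = a.masked_fill(full_mask, 100)
print(torch.equal(b1, b2))  # True
```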

After running:

b = a.masked_fill(mask, value=torch.tensor(100))

tensor b is:

tensor([[[100, 100, 100, 100],
         [100, 100, 100, 100],
         [  7,   7,   7,   7]],

        [[  1,   1,   1,   1],
         [100, 100, 100, 100],
         [100, 100, 100, 100]]])

The size of tensor b is: torch.Size([2, 3, 4])

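From this example we can see that mask does not need the same shape as a. Each True value mask[i][j][0] fills the whole row a[i][j] with 100, and each False value leaves that row unchanged.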
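One way to check this behaviour is to compare masked_fill() with torch.where(), which picks elements from its second argument where the condition is True and from its third argument elsewhere. This is a sketch of an equivalent formulation for this example, not how masked_fill() is implemented internally.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
mask = torch.tensor([[[True], [True], [False]],
                     [[False], [True], [True]]])

b = a.masked_fill(mask, 100)

# torch.where takes 100 where mask is True and keeps a elsewhere;
# the [2, 3, 1] mask broadcasts against the two [2, 3, 4] tensors
w = torch.where(mask, torch.full_like(a, 100), a)
print(torch.equal(b, w))  # True
print(b[0, 2].tolist())   # [7, 7, 7, 7]  (kept because mask[0, 2, 0] is False)
```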

Here is another example, where mask is broadcast along two dimensions:

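import torch

# a has shape [2, 3, 4]
a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
print(a)
print(a.size())
print("#############################################")
# mask has shape [2, 1, 1]: one boolean flag per [3, 4] block of a
mask = torch.tensor([[[0]], [[1]]], dtype=torch.bool)
print(mask.size())
# Fill every position where mask is True with -100
b = a.masked_fill(mask, value=torch.tensor(-100))
print(b)
print(b.size())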

Here tensor a has shape [2, 3, 4] and mask has shape [2, 1, 1], so each of the two mask values covers a whole [3, 4] block of a.
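A mask of shape [2, 1, 1] is often easiest to build from one flag per item along dimension 0. The sketch below assumes such a flag vector (the name batch_flags is hypothetical) and reshapes it with view() so it broadcasts over the last two dimensions.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])

# One flag per item along dimension 0 (shape [2]) ...
batch_flags = torch.tensor([False, True])
# ... reshaped to [2, 1, 1] so it broadcasts over the last two dimensions
mask = batch_flags.view(-1, 1, 1)
print(mask.shape)                 # torch.Size([2, 1, 1])

b = a.masked_fill(mask, -100)
print(torch.equal(b[0], a[0]))    # True: first block untouched
print(b[1].unique().tolist())     # [-100]: second block fully filled
```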

Running this code prints:

tensor([[[5, 5, 5, 5],
         [6, 6, 6, 6],
         [7, 7, 7, 7]],

        [[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]]])
torch.Size([2, 3, 4])
#############################################
torch.Size([2, 1, 1])
tensor([[[   5,    5,    5,    5],
         [   6,    6,    6,    6],
         [   7,    7,    7,    7]],

        [[-100, -100, -100, -100],
         [-100, -100, -100, -100],
         [-100, -100, -100, -100]]])
torch.Size([2, 3, 4])
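Broadcasting aligns shapes from the right, so the mask may also have fewer dimensions than a. In this sketch (the name row_mask is illustrative), a [3, 1] mask is treated as [1, 3, 1], and the same row selection is applied to both items along dimension 0.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])

# [3, 1] is aligned with the last two dimensions of a and acts like [1, 3, 1]
row_mask = torch.tensor([[False], [True], [False]])
b = a.masked_fill(row_mask, 0)
print(b[:, 1].tolist())  # [[0, 0, 0, 0], [0, 0, 0, 0]]  (middle row zeroed in both)
print(b[:, 0].tolist())  # [[5, 5, 5, 5], [1, 1, 1, 1]]  (other rows unchanged)
```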