PyTorch Tensor.masked_fill() is the out-of-place version of Tensor.masked_fill_(mask, value): it returns a new tensor and leaves the original tensor unchanged. In this tutorial, we will use some examples to show you how to use it.
Syntax
Tensor.masked_fill() is defined as:
Tensor.masked_fill(mask, value)
It returns a copy of the self tensor in which elements are replaced with value wherever mask is True. Note that mask should be a boolean tensor, and its shape must be broadcastable with the shape of the underlying tensor.
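To make the difference between the two variants concrete, here is a minimal sketch. The small 2x2 tensor and the fill value 0 are chosen only for illustration and are not part of the examples that follow.

```python
import torch

a = torch.tensor([[1, 2], [3, 4]])
mask = torch.tensor([[True, False], [False, True]])

# Out-of-place: returns a new tensor, a itself is not modified
b = a.masked_fill(mask, 0)
a_unchanged = torch.equal(a, torch.tensor([[1, 2], [3, 4]]))
print(b)
print(a_unchanged)

# In-place (trailing underscore): modifies a directly
a.masked_fill_(mask, 0)
print(a)
```

Use masked_fill() when the original tensor is still needed afterwards, and masked_fill_() when overwriting it is acceptable.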
How to use Tensor.masked_fill()?
Here are some examples.
import torch
# A tensor of shape [2, 3, 4]
a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
print(a)
print(a.size())
print("#############################################")
# A boolean mask of shape [2, 3, 1]; recent PyTorch versions reject ByteTensor masks
mask = torch.tensor([[[1], [1], [0]], [[0], [1], [1]]], dtype=torch.bool)
print(mask.size())
b = a.masked_fill(mask, value=torch.tensor(100))
print(b)
print(b.size())
In this example code, tensor a is:
tensor([[[5, 5, 5, 5],
         [6, 6, 6, 6],
         [7, 7, 7, 7]],

        [[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]]])
torch.Size([2, 3, 4])
We can see that the shape of tensor a is [2, 3, 4].
The size of mask is:
torch.Size([2, 3, 1])
This means the shape of tensor mask is [2, 3, 1].
For the code:
b = a.masked_fill(mask, value=torch.tensor(100))
We will see tensor b is:
tensor([[[100, 100, 100, 100],
         [100, 100, 100, 100],
         [  7,   7,   7,   7]],

        [[  1,   1,   1,   1],
         [100, 100, 100, 100],
         [100, 100, 100, 100]]])
The size of tensor b is: torch.Size([2, 3, 4])
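From this example we can see that tensor a and mask do not need to have the same shape. The mask of shape [2, 3, 1] is broadcast along the last dimension, so each mask value selects an entire row of four elements.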
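The broadcasting in this example can be checked with a short sketch. It assumes a boolean mask, and the names full_mask, b_expanded and b_where are introduced here only for illustration. Expanding the mask by hand, or using torch.where(), should give the same result as relying on broadcasting.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
mask = torch.tensor([[[1], [1], [0]], [[0], [1], [1]]], dtype=torch.bool)

# Let masked_fill() broadcast the [2, 3, 1] mask against the [2, 3, 4] tensor
b = a.masked_fill(mask, 100)

# Expand the mask to the full shape of a explicitly
full_mask = mask.expand_as(a)
b_expanded = a.masked_fill(full_mask, 100)

# Equivalent selection with torch.where(): 100 where mask is True, a elsewhere
b_where = torch.where(mask, torch.tensor(100), a)

print(full_mask.shape)
print(torch.equal(b, b_expanded))
print(torch.equal(b, b_where))
```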
Moreover, look at this example code:
import torch
# A tensor of shape [2, 3, 4]
a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])
print(a)
print(a.size())
print("#############################################")
# A boolean mask of shape [2, 1, 1]; recent PyTorch versions reject ByteTensor masks
mask = torch.tensor([[[0]], [[1]]], dtype=torch.bool)
print(mask.size())
b = a.masked_fill(mask, value=torch.tensor(-100))
print(b)
print(b.size())
In this example:
The shape of tensor a is [2, 3, 4].
The shape of tensor mask is [2, 1, 1], so each mask value is broadcast over a whole [3, 4] block of a.
Running this code, we will see:
tensor([[[5, 5, 5, 5],
         [6, 6, 6, 6],
         [7, 7, 7, 7]],

        [[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]]])
torch.Size([2, 3, 4])
#############################################
torch.Size([2, 1, 1])
tensor([[[   5,    5,    5,    5],
         [   6,    6,    6,    6],
         [   7,    7,    7,    7]],

        [[-100, -100, -100, -100],
         [-100, -100, -100, -100],
         [-100, -100, -100, -100]]])
torch.Size([2, 3, 4])
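In practice the mask is often computed from the data rather than written by hand. As a small sketch, the condition a > 5 below is an arbitrary choice for illustration. A comparison produces a boolean tensor with the same shape as a, which can be passed to masked_fill() directly.

```python
import torch

a = torch.tensor([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7]],
                  [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]])

# The comparison returns a torch.bool tensor of shape [2, 3, 4]
mask = a > 5
print(mask.dtype)
print(mask.size())

# Replace every element greater than 5 with -100
b = a.masked_fill(mask, -100)
print(b)
```

Here no broadcasting is needed, because the mask already has the same shape as a.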