In this tutorial, we will use some examples to show you how to use the torch.bmm() function in PyTorch.
torch.bmm()
It is defined as:
torch.bmm(input, mat2, *, out=None)
It computes a batch matrix-matrix product of the matrices stored in input and mat2. Both input and mat2 must be 3-D tensors that contain the same number of matrices.
input: b * n * m
mat2: b * m * p
torch.bmm() will return a tensor of shape b * n * p.
For example:
import torch

x = torch.randn(4, 5, 6)
y = torch.randn(4, 6, 3)
z = torch.bmm(x, y)
print(z.size())
print(z)
Running this code, we will see:
torch.Size([4, 5, 3])
tensor([[[ 2.3357e+00, -2.3579e-01,  3.3147e+00],
         [-4.1749e+00,  3.9280e+00, -1.8552e+00],
         [ 7.4033e-01,  1.0861e+00, -2.1744e-02],
         [ 1.6855e+00, -2.6260e+00,  6.0316e+00],
         [ 1.3546e+00, -7.9034e-01,  7.1132e-01]],
        ...
        [[ 4.6971e+00,  4.0988e-01, -9.6455e-01],
         [-1.8950e+00, -1.1270e+00,  2.1010e+00],
         [ 8.4094e-01, -1.7533e+00, -1.2964e-01],
         [-1.6315e+00, -4.2563e-02,  2.7969e+00],
         [-6.2584e-01,  2.0024e+00,  7.3514e-01]]])
In this example, b = 4, n = 5, m = 6 and p = 3, so tensor z has shape 4 * 5 * 3.
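To make the batch product more concrete, here is a minimal sketch that compares torch.bmm() with multiplying each pair of matrices one by one using torch.mm(). The names z_loop and same are only used for this illustration, and the tolerance passed to torch.allclose() is an assumed value for float32 rounding differences.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

x = torch.randn(4, 5, 6)
y = torch.randn(4, 6, 3)

# Batched product: one call handles all 4 matrix pairs
z = torch.bmm(x, y)

# Same computation done pair by pair, then stacked along the batch dimension
z_loop = torch.stack([torch.mm(x[i], y[i]) for i in range(x.size(0))])

# The two results should agree up to floating point rounding
same = torch.allclose(z, z_loop, atol=1e-4)
print(z_loop.size())
print(same)
```

In other words, z[i] is the ordinary matrix product of x[i] and y[i], which is why the batch dimension b has to match in both tensors.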
What happens if the batch sizes are different?
For example:
import torch

x = torch.randn(4, 5, 6)
y = torch.randn(3, 6, 3)
z = torch.bmm(x, y)
print(z.size())
print(z)
Running this code, we will get this error:
RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.
It means input and mat2 must have the same batch size. torch.bmm() does not broadcast the batch dimension.
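If one of the tensors holds only a single matrix that should be shared by every item in the batch, there are a couple of possible workarounds. The sketch below is only an illustration: the tensor w and the variable names are hypothetical, and it assumes the second tensor has a batch size of 1. It first confirms that the mismatched call fails, then either expands w to the same batch size before calling torch.bmm(), or uses torch.matmul(), which supports broadcasting.

```python
import torch

x = torch.randn(4, 5, 6)
y = torch.randn(3, 6, 3)

# Batch sizes 4 and 3 do not match, so torch.bmm() raises a RuntimeError
try:
    torch.bmm(x, y)
    failed = False
except RuntimeError:
    failed = True
print(failed)

# Hypothetical case: one 6 x 3 matrix shared by all 4 batch items
w = torch.randn(1, 6, 3)

# Option 1: expand w to batch size 4 (no data is copied), then use bmm
z_expand = torch.bmm(x, w.expand(x.size(0), -1, -1))

# Option 2: torch.matmul() broadcasts the batch dimension itself
z_matmul = torch.matmul(x, w)

print(z_expand.size())
print(z_matmul.size())
```

Note that broadcasting only helps when one batch size is 1. Batch sizes such as 4 and 3 cannot be combined by either function.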