Understand torch.bmm() with Examples – PyTorch Tutorial

By | October 18, 2024

In this tutorial, we will use some examples to show you how to use the torch.bmm() function in PyTorch.

torch.bmm()

It is defined as:

torch.bmm(input, mat2, *, out=None)

It computes a batch matrix-matrix product of the matrices stored in input and mat2. Both must be 3-D tensors with these shapes:

input: b * n * m

mat2: b * m * p

torch.bmm() returns a tensor of shape b * n * p. For each batch index i, the output matrix is the matrix product of input[i] and mat2[i].
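To make the per-batch behavior concrete, here is a minimal sketch that compares torch.bmm() with an explicit loop over the batch dimension. The variable names (z_loop, same) and the tolerance passed to torch.allclose() are only illustrative choices.

```python
import torch

torch.manual_seed(0)

b, n, m, p = 4, 5, 6, 3
x = torch.randn(b, n, m)   # input: b * n * m
y = torch.randn(b, m, p)   # mat2:  b * m * p

# One call multiplies all b pairs of matrices
z = torch.bmm(x, y)

# Reference: multiply each pair of 2-D matrices with torch.mm(), then stack
z_loop = torch.stack([torch.mm(x[i], y[i]) for i in range(b)])

# The two results should agree up to floating point rounding
same = torch.allclose(z, z_loop, atol=1e-5)

print(z.shape)
print(same)
```

The loop is only there for comparison. In practice, the single torch.bmm() call is the form you would normally use.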

For example:

import torch

x = torch.randn(4, 5, 6)
y = torch.randn(4, 6, 3)

z = torch.bmm(x, y)
print(z.size())
print(z)

Running this code, we will see output like this (the values are random, so yours will differ):

torch.Size([4, 5, 3])
tensor([[[ 2.3357e+00, -2.3579e-01,  3.3147e+00],
         [-4.1749e+00,  3.9280e+00, -1.8552e+00],
         [ 7.4033e-01,  1.0861e+00, -2.1744e-02],
         [ 1.6855e+00, -2.6260e+00,  6.0316e+00],
         [ 1.3546e+00, -7.9034e-01,  7.1132e-01]],

        ...

        [[ 4.6971e+00,  4.0988e-01, -9.6455e-01],
         [-1.8950e+00, -1.1270e+00,  2.1010e+00],
         [ 8.4094e-01, -1.7533e+00, -1.2964e-01],
         [-1.6315e+00, -4.2563e-02,  2.7969e+00],
         [-6.2584e-01,  2.0024e+00,  7.3514e-01]]])

In this example, b = 4, n = 5, m = 6 and p = 3, so tensor z has shape 4 * 5 * 3.
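The shape rule b * n * m times b * m * p gives b * n * p can also be written in index notation. As a rough sketch, the same product can be expressed with torch.einsum(), where the shared index m is summed over. The subscript letters are chosen here only to match the names used above.

```python
import torch

x = torch.randn(4, 5, 6)   # b * n * m
y = torch.randn(4, 6, 3)   # b * m * p

# "bnm,bmp->bnp": keep b, n and p, sum over the shared dimension m
z_einsum = torch.einsum("bnm,bmp->bnp", x, y)
z_bmm = torch.bmm(x, y)

# Both forms should give the same 4 * 5 * 3 result up to rounding
matches = torch.allclose(z_einsum, z_bmm, atol=1e-5)

print(z_einsum.shape)
print(matches)
```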

What happens if the batch sizes are different?

For example:

import torch

x = torch.randn(4, 5, 6)
y = torch.randn(3, 6, 3)

z = torch.bmm(x, y)
print(z.size())
print(z)

Running this code, we will get this error:

RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.

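This means input and mat2 must have the same batch size. torch.bmm() does not broadcast, so a batch of 4 matrices cannot be multiplied by a batch of 3.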
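If the goal is to multiply every matrix in a batch by one shared matrix, there are two common options, sketched below: torch.matmul(), which broadcasts over the batch dimension, or expanding the shared matrix to the matching batch size before calling torch.bmm(). The tensor w and the other names here are made up for this illustration.

```python
import torch

x = torch.randn(4, 5, 6)   # batch of 4 matrices, each 5 * 6
w = torch.randn(6, 3)      # a single 6 * 3 matrix shared by every batch

# Option 1: torch.matmul() broadcasts the 2-D matrix across the batch
z_matmul = torch.matmul(x, w)

# Option 2: give w an explicit batch dimension of size 4, then use torch.bmm()
# expand() creates a view, so the data of w is not copied
w_batched = w.unsqueeze(0).expand(x.size(0), -1, -1)   # 4 * 6 * 3
z_bmm = torch.bmm(x, w_batched)

# Both options should agree up to floating point rounding
agree = torch.allclose(z_matmul, z_bmm, atol=1e-5)

print(z_matmul.shape)
print(z_bmm.shape)
print(agree)
```

This only helps when one operand is meant to be shared. Two batches with genuinely different sizes, such as 4 and 3 above, still cannot be multiplied.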
