Both torch.matmul() and torch.mm() can perform matrix multiplication. This tutorial explains the difference between them.
torch.matmul()
It is defined as:
torch.matmul(input, other, *, out=None)
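What torch.matmul() returns depends on the dimensionality of its two arguments. The sketch below is only an illustration: the tensor names and shapes are made up for this example. It records the result shape for a few common combinations.

```python
import torch

# Illustrative tensors; the names and sizes are assumptions for this sketch.
v = torch.randn(3)              # 1-D vector
w = torch.randn(3)              # 1-D vector
A = torch.randn(2, 3)           # 2-D matrix
B = torch.randn(3, 4)           # 2-D matrix
batch = torch.randn(10, 2, 3)   # 3-D: a batch of ten (2x3) matrices

shape_dot = torch.matmul(v, w).shape        # 1-D x 1-D: dot product, 0-D result
shape_mat = torch.matmul(A, B).shape        # 2-D x 2-D: matrix product, (2, 4)
shape_mv = torch.matmul(A, v).shape         # 2-D x 1-D: matrix-vector product, (2,)
shape_batch = torch.matmul(batch, B).shape  # 3-D x 2-D: B is broadcast, (10, 2, 4)

print(shape_dot, shape_mat, shape_mv, shape_batch)
```

Only the 2-D x 2-D case overlaps with what torch.mm() computes.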
torch.mm()
It is defined as:
torch.mm(input, mat2, *, out=None)
Here input is an (n×m) tensor, mat2 is an (m×p) tensor, and the output is an (n×p) tensor.
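The shape rule can be checked directly. In this minimal sketch the sizes n, m and p are arbitrary values picked for illustration. It also shows that torch.mm() raises a RuntimeError when the inner dimensions do not match.

```python
import torch

# Arbitrary sizes chosen for this sketch.
n, m, p = 2, 3, 4

a = torch.randn(n, m)
b = torch.randn(m, p)

# (n x m) times (m x p) gives (n x p)
out_shape = torch.mm(a, b).shape

# The inner dimensions must agree: (2 x 3) times (4 x 3) is not defined.
try:
    torch.mm(a, torch.randn(p, m))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

print(out_shape, mismatch_raised)
```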
Note that torch.mm() does not broadcast and accepts only 2-D tensors. For broadcasting matrix products, use torch.matmul().
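The broadcasting difference is easiest to see with a batched input. The following sketch assumes a made-up batch of ten (3×4) matrices multiplied by a single (4×5) matrix. torch.matmul() applies the matrix to every item in the batch, while torch.mm() rejects the 3-D input.

```python
import torch

# Illustrative shapes; not taken from the tutorial's own example.
batch = torch.randn(10, 3, 4)   # ten (3x4) matrices
mat = torch.randn(4, 5)         # one (4x5) matrix

# matmul broadcasts `mat` across the batch dimension.
out = torch.matmul(batch, mat)  # shape (10, 3, 5)

# The same result, built one slice at a time with torch.mm().
stacked = torch.stack([torch.mm(item, mat) for item in batch])
same = torch.allclose(out, stacked, atol=1e-5)

# torch.mm() itself only accepts 2-D tensors.
try:
    torch.mm(batch, mat)
    mm_failed = False
except RuntimeError:
    mm_failed = True

print(out.shape, same, mm_failed)
```

Looping over the batch with torch.mm() works, but torch.matmul() expresses the same computation in a single call.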
The following example compares the two functions on a pair of 2-D tensors.
import torch

# Two 2-D tensors whose inner dimensions match: (5×200) and (200×5)
x = torch.randn(5, 200)
y = torch.randn(200, 5)

# Both calls compute the same (5×5) matrix product
print(torch.mm(x, y))
print(torch.matmul(x, y))
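Running this code shows that the two results are the same, because both inputs are 2-D. The exact values will differ from run to run, since the inputs are random.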
tensor([[-27.3500,  -4.1209,  16.5225,   3.5113,  23.4709],
        [ 33.1420,   3.8333,  13.8869, -13.8083, -12.5489],
        [  0.1656,   6.9874,  -3.2829,  15.0789, -14.5087],
        [-21.0825,  -2.0691, -14.8276,   3.2864,   6.1927],
        [ -3.1685,  18.8321,   6.1552,   7.9586,   6.6930]])
tensor([[-27.3500,  -4.1209,  16.5225,   3.5113,  23.4709],
        [ 33.1420,   3.8333,  13.8869, -13.8083, -12.5489],
        [  0.1656,   6.9874,  -3.2829,  15.0789, -14.5087],
        [-21.0825,  -2.0691, -14.8276,   3.2864,   6.1927],
        [ -3.1685,  18.8321,   6.1552,   7.9586,   6.6930]])