Understand torch.sum() with Examples – PyTorch Tutorial

By | December 29, 2022

The torch.sum() function is easy to use. In this tutorial, we will use some examples to show you how to use it.

Syntax

torch.sum() is defined as:

torch.sum(input, dim, keepdim=False, *, dtype=None)

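It sums the elements of input over the dimension(s) given by dim and returns a new tensor. dim can be a single int or a tuple of ints. If it is omitted, all elements are summed. If keepdim=True, each reduced dimension is kept with size 1; otherwise it is removed from the result. If dtype is given, input is cast to that type before the sum is computed, which helps avoid overflow. Note that the documented parameter name is dim, not axis as in TensorFlow's tf.reduce_sum(). PyTorch also accepts axis as a NumPy-style alias, but dim is the documented name.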
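To make dim, keepdim and dtype concrete, here is a small sketch on a fixed integer tensor. The tensor x is our own choice and does not appear in the examples below. Its values are small enough that each result can be checked by hand.

```python
import torch

# A small 2 x 3 integer tensor (illustrative values).
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(torch.sum(x))                # dim omitted: all six elements, total 21
print(torch.sum(x, dim=0))         # tensor([5, 7, 9]): one sum per column
print(torch.sum(x, dim=1))         # one sum per row: 6 and 15

# keepdim=True keeps the reduced dimension with size 1.
kept = torch.sum(x, dim=1, keepdim=True)
print(kept.shape)                  # torch.Size([2, 1])

# dtype casts the input before summing, so the result is float64 here.
print(torch.sum(x, dtype=torch.float64).dtype)  # torch.float64
```

With keepdim=True, the (2, 1) result can be broadcast directly against the original (2, 3) tensor, for example to divide each row by its sum.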

Here we will use some examples to show you how to use this function.

dim = None

For example:

>>> import torch
>>> a = torch.randn(1, 3)
>>> a
tensor([[ 0.1133, -0.9567,  0.2958]])
>>> torch.sum(a)
tensor(-0.5475)
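In the example above, the result is a 0-dim (scalar) tensor rather than a Python number. The sketch below uses a fixed seed, so its random values differ from the ones printed above. It shows how to read the value out with .item() and checks that the method form a.sum() and an explicit tuple of every dimension produce the same total.

```python
import torch

torch.manual_seed(0)  # fixed seed only so reruns give the same numbers
a = torch.randn(1, 3)

total = torch.sum(a)      # dim omitted: every element is summed
print(total.dim())        # 0, i.e. a scalar tensor
print(total.item())       # the same value as a plain Python float

# The method form and reducing over all dims explicitly agree.
print(torch.allclose(total, a.sum()))                    # True
print(torch.allclose(total, torch.sum(a, dim=(0, 1))))   # True
```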

dim = 1

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0569, -0.2475,  0.0737, -0.3429],
        [-0.2993,  0.9138,  0.9337, -1.6864],
        [ 0.1132,  0.7892, -0.1003,  0.5688],
        [ 0.3637, -0.9906, -0.4752, -1.5197]])
>>> torch.sum(a, 1)  # dim=1: sum along each row, one value per row
tensor([-0.4598, -0.1381,  1.3708, -2.6217])
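Summing over dim 1 of a 4 x 4 tensor removes that dimension, so the result has shape (4,). The sketch below (seeded, so its numbers differ from the output above) also shows that for a 2-D tensor dim=-1 refers to the same dimension as dim=1, and it rebuilds the result row by row to show what "sum over dim 1" means.

```python
import torch

torch.manual_seed(0)  # fixed seed for repeatable values
a = torch.randn(4, 4)

row_sums = torch.sum(a, 1)    # one value per row
print(row_sums.shape)         # torch.Size([4])

# For a 2-D tensor, dim=-1 (the last dimension) is the same as dim=1.
print(torch.equal(row_sums, torch.sum(a, -1)))   # True

# Same result computed one row at a time.
manual = torch.stack([a[i].sum() for i in range(a.shape[0])])
print(torch.allclose(row_sums, manual))          # True
```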

dim is a tuple

>>> b = torch.arange(4 * 5 * 6).view(4, 5, 6)
>>> torch.sum(b, (2, 1))  # sum over dims 2 and 1, keeping dim 0
tensor([ 435, 1335, 2235, 3135])
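Reducing over a tuple of dims collapses all of them at once. In this example, each of the 4 slices b[i] holds the 30 consecutive integers 30i, ..., 30i + 29, so its sum is 900i + 435. The sketch below checks the result against that closed form. It also shows that the order of dims inside the tuple does not matter and that the result equals summing one dimension at a time.

```python
import torch

b = torch.arange(4 * 5 * 6).view(4, 5, 6)  # int64 values 0..119

s = torch.sum(b, (2, 1))     # reduce dims 2 and 1, keep dim 0
print(s.shape)               # torch.Size([4])
print(s.dtype)               # torch.int64, since arange gives integers here

# Order inside the tuple does not matter.
print(torch.equal(s, torch.sum(b, (1, 2))))          # True

# Same as reducing one dimension after the other.
print(torch.equal(s, b.sum(dim=2).sum(dim=1)))       # True

# Closed form: slice i sums to 900 * i + 435.
expected = 900 * torch.arange(4) + 435
print(torch.equal(s, expected))                      # True
```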