Understand PyTorch Tensor.repeat() with Examples – PyTorch Tutorial

By | March 27, 2023

The PyTorch Tensor.repeat() function repeats a tensor along the specified dimensions. In this tutorial, we will use some examples to show you how to use it.

Syntax

tensor.repeat() is defined as:

Tensor.repeat(*sizes)

sizes: the number of times to repeat this tensor along each dimension. It can be given as several integers or as a single torch.Size (or tuple) of integers.
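A quick sketch of the accepted argument forms. Separate integers, a tuple and a torch.Size should all give the same result. The example also assumes the documented rule that you must pass at least as many repeat counts as the tensor has dimensions; passing fewer raises a RuntimeError.

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])  # shape [2, 2]

# The repeat counts can be passed as separate ints, a tuple, or a torch.Size.
a = x.repeat(2, 3)
b = x.repeat((2, 3))
c = x.repeat(torch.Size([2, 3]))
print(torch.equal(a, b), torch.equal(a, c), a.shape)  # True True torch.Size([4, 6])

# Fewer repeat counts than x.dim() is an error.
raised = False
try:
    x.repeat(2)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```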

How to use tensor.repeat()?

Here is an example:

import torch

x = torch.tensor([1, 2, 3])
y = x.repeat(2)  # repeat the 1-D tensor 2 times along dim 0

print(y, y.shape)
print(x)

Running this code prints:

tensor([1, 2, 3, 1, 2, 3]) torch.Size([6])
tensor([1, 2, 3])

The original tensor x is not modified: x.repeat() returns a new tensor.
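Because repeat() returns a new tensor holding its own copy of the data, writing into the result should leave the source untouched. A minimal check:

```python
import torch

x = torch.tensor([1, 2, 3])
y = x.repeat(2)

# repeat() copies the data, so writing into y does not change x.
y[0] = 100
print(y)  # the first element is now 100
print(x)  # tensor([1, 2, 3])
```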

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
y = x.repeat(4, 2)  # 4 times along dim 0, 2 times along dim 1

print(y, y.shape)
print(x)

Running this code prints:

tensor([[1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6]]) torch.Size([8, 6])
tensor([[1, 2, 3],
        [4, 5, 6]])


x.repeat(4, 2) repeats the data of x 4 times along dim 0 (rows) and 2 times along dim 1 (columns).
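One way to see that repeat() tiles the whole tensor, not each element, is to rebuild the same result with torch.cat. This is only an illustration of the output. It is not a claim about how repeat() is implemented internally.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Tile the whole block 2 times along dim 1, then 4 times along dim 0.
cols = torch.cat([x] * 2, dim=1)       # shape [2, 6]
manual = torch.cat([cols] * 4, dim=0)  # shape [8, 6]

print(torch.equal(manual, x.repeat(4, 2)))  # True
```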

The result is a new tensor with shape [8, 6]: 2 × 4 = 8 rows and 3 × 2 = 6 columns.
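In general, each output dimension is the input size multiplied by the matching repeat count. If you pass more repeat counts than the tensor has dimensions, the tensor is treated as though it had extra leading dimensions of size 1. The helper repeated_shape below is hypothetical. It was written only to check this rule against PyTorch's actual output.

```python
import torch

def repeated_shape(shape, sizes):
    """Hypothetical helper: predict the shape returned by Tensor.repeat(*sizes)."""
    if len(sizes) < len(shape):
        raise ValueError("need at least one repeat count per dimension")
    # Pad the input shape with leading 1s so it lines up with sizes.
    padded = (1,) * (len(sizes) - len(shape)) + tuple(shape)
    return torch.Size([s * r for s, r in zip(padded, sizes)])

x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]

print(repeated_shape(x.shape, (4, 2)), x.repeat(4, 2).shape)
# torch.Size([8, 6]) torch.Size([8, 6])
print(repeated_shape(x.shape, (3, 4, 2)), x.repeat(3, 4, 2).shape)
# torch.Size([3, 8, 6]) torch.Size([3, 8, 6])
```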