The PyTorch tensor.repeat() function repeats a tensor along the specified dimensions. In this tutorial, we will use some examples to show you how to use it.
Syntax
tensor.repeat() is defined as:
Tensor.repeat(*sizes)
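sizes: one or more integers, or a single torch.Size. Each value is the number of times to repeat this tensor along the corresponding dimension. The number of values must be at least the number of dimensions of the tensor.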
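As a small sketch of the two ways to pass the repeat counts, the snippet below calls repeat() once with separate integers and once with a single torch.Size. The variable names a and b are only for illustration. It also shows that a 1-D tensor can be given more repeat counts than it has dimensions, in which case it is treated as if it had extra leading dimensions of size 1.

```python
import torch

x = torch.tensor([1, 2, 3])            # shape [3]

# Repeat counts passed as separate integers.
a = x.repeat(2, 2)

# The same repeat counts passed as a single torch.Size.
b = x.repeat(torch.Size([2, 2]))

print(a.shape)             # torch.Size([2, 6])
print(torch.equal(a, b))   # True
```

Here x is treated as having shape [1, 3], so the result has shape [1 * 2, 3 * 2] = [2, 6].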
How to use tensor.repeat()?
Here is an example:
import torch

x = torch.tensor([1, 2, 3])
y = x.repeat(2)
print(y, y.shape)
print(x)
Running this code, we will see:
tensor([1, 2, 3, 1, 2, 3]) torch.Size([6])
tensor([1, 2, 3])
We can see that the tensor x is not modified: x.repeat() returns a new tensor that holds a copy of the data.
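To check that the result is independent of the original, a minimal sketch is to modify the repeated tensor in place and then print both tensors. The value 100 is an arbitrary choice for illustration.

```python
import torch

x = torch.tensor([1, 2, 3])
y = x.repeat(2)

# Change one element of the repeated tensor in place.
y[0] = 100

print(x)   # tensor([1, 2, 3])
print(y)   # tensor([100, 2, 3, 1, 2, 3])
```

Only the first element of y changes. The second copy of the value inside y, and x itself, keep the value 1.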
Next, we repeat a 2-D tensor along both dimensions:
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x.repeat(4, 2)
print(y, y.shape)
print(x)
Running this code, we will see:
tensor([[1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6]]) torch.Size([8, 6])
tensor([[1, 2, 3],
        [4, 5, 6]])
x.repeat(4, 2) means the data in x is repeated 4 times along dim 0 and 2 times along dim 1.
Because x has shape [2, 3], we get a new tensor with shape [2 * 4, 3 * 2] = [8, 6].
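The shape rule can be sketched as a small helper. The function repeated_shape below is a hypothetical name used only for this illustration and is not part of PyTorch. It assumes at least as many repeat counts as tensor dimensions, pads the input shape with leading 1s when there are more, and multiplies each dimension by its repeat count.

```python
import torch

def repeated_shape(shape, sizes):
    # Hypothetical helper: predict the shape that repeat(*sizes) produces.
    # Pad the input shape with leading 1s so it has as many entries as sizes.
    padded = [1] * (len(sizes) - len(shape)) + list(shape)
    # Each output dimension is the input dimension times its repeat count.
    return [dim * count for dim, count in zip(padded, sizes)]

x = torch.tensor([[1, 2, 3], [4, 5, 6]])     # shape [2, 3]

print(repeated_shape(x.shape, (4, 2)))       # [8, 6]
print(list(x.repeat(4, 2).shape))            # [8, 6]

# With three repeat counts, x is treated as having shape [1, 2, 3].
print(repeated_shape(x.shape, (2, 4, 2)))    # [2, 8, 6]
print(list(x.repeat(2, 4, 2).shape))         # [2, 8, 6]
```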