In this tutorial, we will use some examples to show you how to use the torch.sort() function in PyTorch correctly.
torch.sort()
It is defined as:
torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
This function sorts the elements of the input tensor along a given dimension. By default the values are sorted in ascending order; pass descending=True to sort in descending order.
Input parameters
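input: the input tensor.
dim: the dimension along which to sort. The default, -1, is the last dimension.
descending: if True, sort in descending order. Default False.
stable: if True, use a stable sort, which keeps equivalent (equal) elements in the same relative order as in input. Default False.
out: an optional (Tensor, LongTensor) tuple to write the results into.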
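The signature above also has a descending flag, which reverses the sort order. A minimal sketch, using a small hand-written tensor (these values are illustrative and not part of the original examples):

```python
import torch

# A small tensor with no ties, so the order of the indices is unambiguous
x = torch.tensor([3.0, 1.0, 2.0])

# descending=True puts the largest value first
values, indices = torch.sort(x, descending=True)
print(values)   # tensor([3., 2., 1.])
print(indices)  # tensor([0, 2, 1])
```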
Output
This function returns a named tuple (values, indices), i.e. (Tensor, LongTensor):
Tensor (values): the sorted tensor.
LongTensor (indices): for each sorted element, its index along dim in the original input.
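Because the result is a named tuple, you can also read the two parts by field name. The sketch below (with a made-up 2x3 tensor) checks that the indices have dtype int64 and that torch.gather, applied along the same dim, rebuilds the sorted values from the original tensor:

```python
import torch

# Illustrative input with no repeated values in any row
x = torch.tensor([[0.5, -1.0, 2.0],
                  [3.0, 0.0, -2.0]])

# Sort each row; the result exposes .values and .indices
result = torch.sort(x, dim=1)

# result.indices holds [[1, 0, 2], [2, 1, 0]] and is a LongTensor (int64)
print(result.indices)
print(result.indices.dtype)  # torch.int64

# Gathering x with the indices along the same dim reproduces the sorted values
rebuilt = torch.gather(x, 1, result.indices)
print(torch.equal(rebuilt, result.values))  # True
```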
How to use?
Here we will use some examples to show you how to use this function.
For example:
```python
import torch

x = torch.randn(3, 4)
sorted_tensor, tensor_indices = torch.sort(x)
print("x", x)
print(sorted_tensor)
print(tensor_indices)
```
Run this code and you will get output like this (x is random, so your values will differ):
```
x tensor([[-0.2957, -0.4860, -1.7153, -0.5024],
        [-0.4195, -0.1657, -0.9259, -0.6146],
        [ 1.0523, -1.3726, 0.8340, -2.0896]])
tensor([[-1.7153, -0.5024, -0.4860, -0.2957],
        [-0.9259, -0.6146, -0.4195, -0.1657],
        [-2.0896, -1.3726, 0.8340, 1.0523]])
tensor([[2, 3, 1, 0],
        [2, 3, 0, 1],
        [3, 1, 2, 0]])
```
You can understand this result as follows:
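Because dim defaults to -1, each row of x is sorted independently along the last dimension, and tensor_indices records where each sorted value was located in the original row. For example, the smallest value in the first row, -1.7153, is at column 2 of x, so tensor_indices[0][0] is 2.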
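To make the meaning of tensor_indices concrete, the sketch below reuses the x values printed above (typed in by hand instead of drawn with torch.randn) and checks, element by element, that looking up x at each returned index gives back the sorted value:

```python
import torch

# The x printed in the example above, entered by hand so the result is reproducible
x = torch.tensor([[-0.2957, -0.4860, -1.7153, -0.5024],
                  [-0.4195, -0.1657, -0.9259, -0.6146],
                  [ 1.0523, -1.3726,  0.8340, -2.0896]])

# dim=-1 (the default): sort each row
sorted_tensor, tensor_indices = torch.sort(x)

# tensor_indices[i, j] is the column of row i in x that holds sorted_tensor[i, j]
for i in range(x.size(0)):
    for j in range(x.size(1)):
        assert x[i, tensor_indices[i, j].item()] == sorted_tensor[i, j]

print(tensor_indices[0])  # tensor([2, 3, 1, 0])
```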
What happens if dim = 0?
For example:
```python
import torch

x = torch.randn(3, 4)
# dim=0: sort each column from top to bottom
sorted_tensor, tensor_indices = torch.sort(x, dim=0)
print("x", x)
print(sorted_tensor)
print(tensor_indices)
```
Run this code and you will see output like this (again, x is random):
```
x tensor([[ 0.8447, 0.0204, 0.1753, 0.7887],
        [-0.2143, -0.0590, -0.4153, -0.4881],
        [-0.0828, -0.1560, -0.6996, -0.3439]])
tensor([[-0.2143, -0.1560, -0.6996, -0.4881],
        [-0.0828, -0.0590, -0.4153, -0.3439],
        [ 0.8447, 0.0204, 0.1753, 0.7887]])
tensor([[1, 2, 2, 1],
        [2, 1, 1, 2],
        [0, 0, 0, 0]])
```
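This time each column of x is sorted independently, and tensor_indices gives the row of x that each sorted value came from. For example, in the first column -0.2143 (row 1) is the smallest, then -0.0828 (row 2), then 0.8447 (row 0), so the first column of tensor_indices is 1, 2, 0.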
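One way to see that dim=0 simply sorts columns: transposing x, sorting along dim=1, and transposing back gives the same result. A small sketch of that check, using the x printed above (entered by hand; it has no repeated values within a column, so the indices are unambiguous):

```python
import torch

# The x printed in the dim=0 example above
x = torch.tensor([[ 0.8447,  0.0204,  0.1753,  0.7887],
                  [-0.2143, -0.0590, -0.4153, -0.4881],
                  [-0.0828, -0.1560, -0.6996, -0.3439]])

# Sort down the columns directly
v0, i0 = torch.sort(x, dim=0)

# Sort the rows of the transpose, then transpose the results back
v1, i1 = torch.sort(x.t(), dim=1)

print(torch.equal(v0, v1.t()))  # True
print(torch.equal(i0, i1.t()))  # True
# i0 holds [[1, 2, 2, 1], [2, 1, 1, 2], [0, 0, 0, 0]]
```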
What about stable = True?
It matters when a dimension contains duplicate (equal) elements: a stable sort keeps equal elements in the same relative order they had in the input.
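Here is a small sketch of what that guarantee looks like on a hand-made tensor with repeated values (the values are illustrative, not from the original):

```python
import torch

# 2 appears at positions 0 and 2, 1 at positions 1 and 3, 0 at position 4
x = torch.tensor([2, 1, 2, 1, 0])

values, indices = torch.sort(x, stable=True)

# values holds [0, 1, 1, 2, 2]
# indices holds [4, 1, 3, 0, 2]: within each group of equal values,
# the original positions appear in increasing order
print(values)
print(indices)
```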
For example:
```python
import torch

x = torch.tensor([0, 1] * 20)
print("x", x)

sorted_tensor, tensor_indices = torch.sort(x, stable=False)
print(sorted_tensor)
print(tensor_indices)

sorted_tensor_stable, tensor_indices_stable = torch.sort(x, stable=True)
print(sorted_tensor_stable)
print(tensor_indices_stable)
```
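Run this code and you will get the following (with stable=False the order of the indices is not guaranteed, so it may differ on your machine):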
```
x tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 20, 2, 22, 4, 24, 6, 26, 8, 28, 10, 30, 12, 32, 14, 34, 16, 36, 18, 38, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 25, 31, 23, 33, 27, 35, 21, 37, 29, 39])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39])
```
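We can see that sorted_tensor and sorted_tensor_stable are the same, but tensor_indices and tensor_indices_stable are different.
With stable = True, equal elements keep their original relative order: the 0s come from indices 0, 2, 4, ..., 38 in increasing order, followed by the 1s from indices 1, 3, 5, ..., 39. With stable = False there is no such guarantee; above, for example, index 20 comes before index 2. So when a dimension contains equal elements and the order of the returned indices matters to you, use stable = True.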
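To check the difference programmatically, here is a small helper (is_stable_order is a hypothetical name, not a PyTorch function) that tests whether, wherever two neighbouring sorted values are equal, their original indices are increasing. It is applied to the stable result and to the stable=False indices copied from the output above:

```python
import torch

def is_stable_order(values, indices):
    # Hypothetical helper for 1-D results: True if every pair of adjacent
    # equal values in the sorted output has increasing original indices
    same = values[1:] == values[:-1]
    increasing = indices[1:] > indices[:-1]
    return bool(torch.all(increasing[same]))

x = torch.tensor([0, 1] * 20)
v, i = torch.sort(x, stable=True)
print(is_stable_order(v, i))  # True

# The stable=False indices printed in the tutorial output
printed = torch.tensor([0, 20, 2, 22, 4, 24, 6, 26, 8, 28, 10, 30, 12, 32, 14, 34,
                        16, 36, 18, 38, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 25, 31,
                        23, 33, 27, 35, 21, 37, 29, 39])
print(is_stable_order(x[printed], printed))  # False
```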
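A common reason to rely on stable = True is sorting by more than one key: sort by the secondary key first, then stably by the primary key, and records with equal primary keys stay ordered by the secondary key. A minimal sketch with made-up records (primary and secondary are hypothetical names):

```python
import torch

# Hypothetical records, each with a primary and a secondary key
primary = torch.tensor([1, 0, 1, 0, 1])
secondary = torch.tensor([5, 3, 2, 4, 1])

# Step 1: order the records by the secondary key
_, by_secondary = torch.sort(secondary, stable=True)

# Step 2: stably sort those records by the primary key; records with equal
# primary keys keep the secondary-key order from step 1
_, perm = torch.sort(primary[by_secondary], stable=True)
order = by_secondary[perm]

print(primary[order])    # tensor([0, 0, 1, 1, 1])
print(secondary[order])  # tensor([3, 4, 1, 2, 5])
```

Without stable = True in step 2, records that share a primary key could come back in any order, and the secondary-key ordering from step 1 might be lost.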