In this tutorial, we will use some examples to show you how to use the PyTorch torch.max() function, which returns the maximum value of a tensor.
torch.max()
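Its simplest form is:
torch.max(input)
which returns the maximum value of all elements in the input tensor. Called as torch.max(input, dim, keepdim=False), it instead returns the maximum values along dimension dim together with their indices.
There are a few details to be aware of when using each form.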
torch.max() without dimension
Called with only a tensor, torch.max() reduces over all elements and returns a single maximum value.
For example:
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
m = torch.max(input)
print(m)
Because no dimension is given, torch.max() returns the largest element in input, 200, as a 0-dimensional tensor:
tensor(200.)
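The value returned here is a 0-dimensional tensor, not a Python number. A minimal sketch, assuming you want the plain number on the Python side, calls .item() on it; the Tensor method input.max() is an equivalent spelling of torch.max(input).

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6],
], dtype=torch.float32)

m = torch.max(input)
print(m.dim())       # 0: a scalar (0-dimensional) tensor
print(m.item())      # 200.0 as a plain Python float
print(input.max())   # tensor(200.), same result via the Tensor method
```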
torch.max() with dimension
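When a dimension is given, it returns a named tuple (values, indices): the maximum values along dim, and the position of each maximum within that dimension. By default, dim is removed from both outputs.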
For example:
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6]
], dtype=torch.float32)
max_value, max_indices = torch.max(input, dim=0)
print(input)
print(max_value)
print(max_indices)
Here we use torch.max() with dim = 0, which takes the maximum of each column. The output is:
tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([  4., 200., 100.])
tensor([1, 2, 0])
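The returned pair is a named tuple, so its fields can also be read as .values and .indices. The sketch below (the names result and picked are just illustrative) also shows two details of the dim = 0 output: column 0 contains 4 in both row 1 and row 3, and the PyTorch documentation states that the index of the first maximal value is returned in that case; and gathering input at the returned indices reproduces the maximum values.

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6],
], dtype=torch.float32)

result = torch.max(input, dim=0)
# The result can be unpacked like a tuple or read by field name.
print(result.values)    # tensor([  4., 200., 100.])
print(result.indices)   # tensor([1, 2, 0])

# Use the indices to pick elements back out of input, column by column:
# picked[j] = input[result.indices[j], j]
picked = input.gather(0, result.indices.unsqueeze(0)).squeeze(0)
print(torch.equal(picked, result.values))  # True
```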
If we set dim = 1, torch.max() takes the maximum of each row instead:
max_value, max_indices = torch.max(input, dim=1)
print(input)
print(max_value)
print(max_indices)
The result is:
tensor([[  1.,   2., 100.],
        [  4.,   3.,   3.],
        [  1., 200.,   3.],
        [  4.,   5.,   6.]])
tensor([100.,   4., 200.,   6.])
tensor([2, 0, 1, 2])
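For a 2-D tensor, dim = -1 counts from the end and therefore means the same as dim = 1. If only the positions of the maxima are needed, torch.argmax() returns the indices on their own. A short sketch:

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6],
], dtype=torch.float32)

# For a 2-D tensor, dim=-1 refers to the last dimension, i.e. dim=1.
values, indices = torch.max(input, dim=-1)
print(values)                       # tensor([100.,   4., 200.,   6.])
print(indices)                      # tensor([2, 0, 1, 2])

# If only the positions are needed, torch.argmax returns just the indices.
print(torch.argmax(input, dim=1))   # tensor([2, 0, 1, 2])
```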
We can also pass keepdim=True, which keeps the reduced dimension in the output with size 1 instead of removing it. For example:
import torch

# Shape [1, 2, 2, 3]: dimension 1 holds two 2x3 blocks
input = torch.tensor([[[[1, 2, 100], [4, 3, 3]],
                       [[1, 200, 3], [4, 5, 6]]]], dtype=torch.float32)
print(input.shape)

# Reduce over dim 1 but keep it in the output with size 1
max_value, max_indices = torch.max(input, dim=1, keepdim=True)
print(input)
print(max_value)
print(max_indices)
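Running this code prints the following. Because keepdim=True, max_value and max_indices have shape [1, 1, 2, 3] rather than [1, 2, 3]: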
torch.Size([1, 2, 2, 3])
tensor([[[[  1.,   2., 100.],
          [  4.,   3.,   3.]],

         [[  1., 200.,   3.],
          [  4.,   5.,   6.]]]])
tensor([[[[  1., 200., 100.],
          [  4.,   5.,   6.]]]])
tensor([[[[0, 1, 0],
          [0, 1, 1]]]])
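keepdim=True is mostly useful when the result has to be combined with the original tensor, because a size-1 dimension broadcasts. A minimal sketch using the 2-D tensor from earlier (row_max, row_max_flat, gap and broadcast_failed are illustrative names): with keepdim=True the row maxima have shape [4, 1] and can be subtracted from every column; without it they have shape [4], which broadcasting lines up against the 3 columns, so the subtraction raises a RuntimeError.

```python
import torch

input = torch.tensor([
    [1, 2, 100],
    [4, 3, 3],
    [1, 200, 3],
    [4, 5, 6],
], dtype=torch.float32)

# keepdim=True: shape [4, 1], one maximum per row
row_max, _ = torch.max(input, dim=1, keepdim=True)
print(row_max.shape)  # torch.Size([4, 1])

# The size-1 column broadcasts across all 3 columns, so each row
# is compared with its own maximum.
gap = row_max - input
print(gap)

# keepdim=False (the default): shape [4], which broadcasting aligns
# with the last dimension of input (size 3), so the shapes do not match.
row_max_flat, _ = torch.max(input, dim=1)
broadcast_failed = False
try:
    row_max_flat - input
except RuntimeError:
    broadcast_failed = True
print(broadcast_failed)  # True
```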