The TensorFlow tf.argmax() function is often used in classification problems, where it turns a model's per-class scores into predicted class indices. In this tutorial, we will show TensorFlow beginners how to use this function correctly.
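To make the "get the predictions" step concrete, here is a small plain-Python sketch. It assumes each row holds one example's scores for four classes (the numbers are made up), and `predict_labels` is a hypothetical helper, not part of TensorFlow. It picks the index of the highest score in each row, which is what tf.argmax returns for such a batch when you reduce across the class axis (axis = 1).

```python
def predict_labels(scores):
    """Hypothetical helper: index of the highest score in each row (one row per example)."""
    # max() over the indices, keyed by the score, returns the first index on ties
    return [max(range(len(row)), key=row.__getitem__) for row in scores]

# Three examples, four classes each (illustrative scores only)
scores = [
    [0.1, 0.7, 0.1, 0.1],
    [0.3, 0.2, 0.4, 0.1],
    [0.9, 0.05, 0.03, 0.02],
]
print(predict_labels(scores))  # [1, 2, 0]
```

Each predicted label is a zero-based class index, so label 0 means the first class.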
Syntax
tf.argmax( input, axis=None, name=None, dimension=None, output_type=tf.int64 )
Returns the index of the largest value across an axis of a tensor. Indices start at 0, so the first element has index 0.
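As a quick check of the zero-based index, here is a plain-Python sketch for a 1-D list. It only models the value tf.argmax would return for a vector; it is not how TensorFlow computes it.

```python
values = [3.0, 8.0, 2.0]

# Index of the largest element; like tf.argmax results, Python indices start at 0
max_index = max(range(len(values)), key=lambda i: values[i])
print(max_index)  # 1
```

The largest value, 8.0, is the second element, so its index is 1, not 2.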
Parameters explained
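input: the tensor in which to find the index of the largest value.
axis: the axis of input along which to search for the largest value.
dimension: a deprecated alias for axis.
output_type: the dtype of the returned indices, tf.int64 by default (tf.int32 is also allowed).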
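One way to think about axis is that it names the dimension that gets compared and then disappears from the result. The sketch below models this for 2-D nested lists in plain Python. `argmax_2d` is a hypothetical helper written for illustration: it only handles axis 0 and 1 (TensorFlow also accepts negative axes) and is not TensorFlow's implementation.

```python
def argmax_2d(matrix, axis):
    """Hypothetical plain-Python model of tf.argmax for a 2-D nested list."""
    if axis == 0:
        # Compare down each column: one index (a row number) per column
        lines = list(zip(*matrix))
    elif axis == 1:
        # Compare along each row: one index (a column number) per row
        lines = matrix
    else:
        raise ValueError("this sketch only supports axis 0 or 1")
    return [max(range(len(line)), key=line.__getitem__) for line in lines]

m = [[2, 7],
     [5, 1],
     [4, 6]]  # shape (3, 2)

print(argmax_2d(m, axis=0))  # [1, 0]     length 2: one result per column
print(argmax_2d(m, axis=1))  # [1, 0, 1]  length 3: one result per row
```

For an input of shape (3, 2), reducing along axis 0 leaves a result of length 2, and reducing along axis 1 leaves a result of length 3.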
Here are some examples showing how to use this function.
For a 2-dimensional input tensor
import tensorflow as tf
import numpy as np

# A 2 x 3 input tensor; the integer array is converted to float32
x = tf.Variable(np.array([[1, 9, 3], [4, 5, 6]]), dtype=tf.float32)
Here we create an input tensor x with shape (2, 3), that is, 2 rows and 3 columns. Because x has two dimensions, axis can be 0 or 1.
Get the index of the largest value along axis = 0
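max_index = tf.argmax(x, axis=0)

# TensorFlow 1.x graph mode: initialize the variables, then evaluate in a session.
# In TensorFlow 2.x these APIs live under tf.compat.v1 and need eager execution disabled.
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run([max_index]))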
The index result is:
[array([1, 0, 1], dtype=int64)]
Why is the result [1, 0, 1]?
To understand it, look at the picture below.
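When axis = 0, tf.argmax compares the values down each column and returns the row index of the largest one. In column 0, 4 > 1, so the index is 1; in column 1, 9 > 5, so the index is 0; in column 2, 6 > 3, so the index is 1. So the result is [1, 0, 1].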
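The same column-by-column comparison can be traced in plain Python. This sketch uses a nested list in place of the tensor x and only models the result of tf.argmax(x, axis = 0); it is not TensorFlow's implementation.

```python
x = [[1, 9, 3],
     [4, 5, 6]]

# axis = 0: walk down each column and record the row index of its largest value
result = []
for col, column in enumerate(zip(*x)):
    best_row = max(range(len(column)), key=column.__getitem__)
    print(f"column {col}: values {column}, largest at row {best_row}")
    result.append(best_row)

print(result)  # [1, 0, 1]
```

There is one result per column, which is why the output has 3 elements even though x has only 2 rows.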
Get the index of the largest value along axis = 1
max_index = tf.argmax(x, axis = 1)
The result is:
[array([1, 2], dtype=int64)]
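This is easy to understand: when axis = 1, tf.argmax compares the values within each row. In row 0, [1, 9, 3], the largest value 9 is at index 1; in row 1, [4, 5, 6], the largest value 6 is at index 2.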
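The row-by-row comparison can be sketched the same way. Again, this is a plain-Python model of what tf.argmax(x, axis = 1) returns, using a nested list in place of the tensor.

```python
x = [[1, 9, 3],
     [4, 5, 6]]

# axis = 1: scan each row and record the column index of its largest value
result = []
for r, row in enumerate(x):
    best_col = max(range(len(row)), key=row.__getitem__)
    print(f"row {r}: values {row}, largest at column {best_col}")
    result.append(best_col)

print(result)  # [1, 2]
```

Here there is one result per row, so the output has 2 elements, matching the 2 rows of x.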