Understand TensorFlow tf.argmax() and Axis for Beginners – TensorFlow Tutorial

By | December 2, 2019

The TensorFlow tf.argmax() function is often used in classification problems, where it turns a model's output scores into predicted class labels. In this tutorial, we will show TensorFlow beginners how to use this function correctly.
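Here is a rough illustration of that prediction step. The pure-Python sketch below treats each row of a hypothetical scores list as one example's class scores, and picks the position of the highest score in each row. The data and the helper name predict_labels are made up for illustration. In TensorFlow, the equivalent call would be tf.argmax(scores, axis=1).

```python
# Hypothetical model outputs: one row per example, one score per class.
scores = [
    [0.1, 2.3, 0.5],   # example 0: class 1 scores highest
    [1.2, 0.3, 0.9],   # example 1: class 0 scores highest
    [0.2, 0.4, 3.1],   # example 2: class 2 scores highest
]

def predict_labels(batch):
    """Return the zero-based index of the highest score in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch]

predictions = predict_labels(scores)
print(predictions)  # [1, 0, 2]
```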

Syntax

tf.argmax(
    input,
    axis=None,
    name=None,
    dimension=None,
    output_type=tf.int64
)

tf.argmax() returns the index of the largest value along an axis of a tensor. Indices are zero-based, so the first position along the axis has index 0.
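For a 1-D input, you can check the zero-based index by hand. The sketch below uses a hypothetical plain-Python helper called argmax_1d, which is not a TensorFlow function. It scans the list and keeps the position of the largest value seen so far.

```python
def argmax_1d(values):
    """Zero-based index of the largest element (first one wins on ties)."""
    best = 0
    for i in range(1, len(values)):
        # Only move when strictly larger, so earlier maxima are kept.
        if values[i] > values[best]:
            best = i
    return best

print(argmax_1d([1, 9, 3]))  # 1, because indexing starts at 0
print(argmax_1d([4, 5, 6]))  # 2
```

This helper always returns the first maximum. TensorFlow's documentation, by contrast, does not guarantee which index tf.argmax() returns when several elements tie.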

Parameters explained

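input: the tensor in which to find the index of the largest value.

axis: the axis of input along which to search for the largest value. If axis is None, TensorFlow searches along axis 0.

name: an optional name for the operation.

dimension: a deprecated alias for axis.

output_type: the integer type of the returned indices, tf.int64 by default (tf.int32 is also accepted).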
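The following minimal sketch shows what the axis argument does for 2-D nested Python lists, without running a session. argmax_2d is a hypothetical helper written for this tutorial. Its default of axis=0 mirrors what TensorFlow does when axis is None. Like TensorFlow, it accepts negative axes, which count from the end.

```python
def argmax_2d(matrix, axis=0):
    """Sketch: zero-based argmax along one axis of a 2-D nested list.

    axis=0 searches down each column, axis=1 searches along each row.
    Negative axes count from the end (-1 is the last axis).
    """
    if axis not in (0, 1, -1, -2):
        raise ValueError("axis must be 0, 1, -1 or -2 for a 2-D input")
    axis %= 2  # -1 -> 1, -2 -> 0
    if axis == 0:
        lines = zip(*matrix)  # regroup the values column by column
    else:
        lines = matrix        # each row is already a line to search
    return [max(range(len(line)), key=line.__getitem__) for line in lines]


x = [[1, 9, 3],
     [4, 5, 6]]

print(argmax_2d(x, axis=0))   # [1, 0, 1]
print(argmax_2d(x, axis=1))   # [1, 2]
print(argmax_2d(x, axis=-1))  # [1, 2], same as axis=1
```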

Here are some examples showing how to use this function.

A 2-dimensional input tensor

import tensorflow as tf
import numpy as np

x = tf.Variable(np.array([[1, 9, 3],[4, 5, 6]]), dtype = tf.float32)

Here we create an input tensor x whose shape is 2 x 3. Because x is 2-dimensional, the axis argument can be 0 or 1.
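Searching along an axis removes that axis from the result. For the 2 x 3 tensor x, this is why axis = 0 yields 3 indices and axis = 1 yields 2. The small sketch below computes that output shape. It also rejects axes outside the range [-rank, rank), which TensorFlow treats as invalid. The function argmax_output_shape is hypothetical and not part of TensorFlow.

```python
def argmax_output_shape(shape, axis):
    """Shape of the index tensor argmax returns for an input of `shape`."""
    rank = len(shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    axis %= rank  # map negative axes, e.g. -1 -> rank - 1
    # Drop the searched axis; every other dimension is kept.
    return tuple(shape[:axis]) + tuple(shape[axis + 1:])

print(argmax_output_shape((2, 3), 0))   # (3,)
print(argmax_output_shape((2, 3), 1))   # (2,)
print(argmax_output_shape((2, 3), -1))  # (2,)
```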

Get the index of the largest value along axis = 0

max_index = tf.argmax(x, axis = 0)
init = tf.global_variables_initializer() 
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run([max_index]))

The index result is:

[array([1, 0, 1], dtype=int64)]

Why is the result [1, 0, 1]?

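With axis = 0, tf.argmax() compares the values down each column (that is, across the rows) and returns the row index of the largest value in every column:

column 0 compares 1 and 4, so the index is 1;
column 1 compares 9 and 5, so the index is 0;
column 2 compares 3 and 6, so the index is 1.

So the result is [1, 0, 1].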
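There is another way to see the axis = 0 result. Searching down the columns of x is the same as searching along the rows of its transpose. So tf.argmax(x, axis=0) gives the same indices as tf.argmax(tf.transpose(x), axis=1). Here is a plain-Python check of that idea, using a hypothetical row_argmax helper:

```python
x = [[1, 9, 3],
     [4, 5, 6]]

def row_argmax(matrix):
    """Zero-based index of the largest value in each row."""
    return [max(range(len(r)), key=r.__getitem__) for r in matrix]

# Transpose: the columns of x become the rows of x_t.
x_t = [list(col) for col in zip(*x)]  # [[1, 4], [9, 5], [3, 6]]

print(row_argmax(x_t))  # [1, 0, 1], the same indices as axis = 0 on x
```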

Get the index of the largest value along axis = 1

max_index = tf.argmax(x, axis = 1)

The result is:

[array([1, 2], dtype=int64)]

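With axis = 1, tf.argmax() compares the values within each row and returns the column index of the largest one. In row [1, 9, 3], the largest value is 9, at index 1. In row [4, 5, 6], the largest value is 6, at index 2.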
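Note that tf.argmax() returns positions, not the largest values themselves. This plain-Python sketch uses the axis = 1 indices to read the values back out of each row. In TensorFlow, tf.reduce_max(x, axis=1) would return those values directly.

```python
x = [[1, 9, 3],
     [4, 5, 6]]

# Positions of the row maxima, as tf.argmax(x, axis=1) reports them.
indices = [max(range(len(row)), key=row.__getitem__) for row in x]

# Use each position to read the value back out of its row.
values = [row[i] for row, i in zip(x, indices)]

print(indices)  # [1, 2]
print(values)   # [9, 6]
```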
