Understand Tensor Axis and Shape with Examples: A Beginner's Guide – TensorFlow Tutorial

By | November 15, 2019

Tensor axis and shape are very important when we compute in TensorFlow. What is the relationship between them? In this tutorial, we will discuss it with examples.

What is tensor axis?

A tensor axis refers to one dimension of a tensor. Axes are numbered starting from 0.

For example, if a tensor is 4-dimensional, its axes are axis = [0, 1, 2, 3].

In general, if a tensor has n dimensions, its axes are axis = [0, 1, 2, …, n-1].
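You can check this numbering quickly with NumPy, which the example later in this tutorial already imports. NumPy arrays follow the same axis convention as TensorFlow tensors. The 4-dimensional array below is an arbitrary example: `ndim` gives n, and `range(n)` lists the valid axes.

```python
import numpy as np

# An arbitrary 4-dimensional array (NumPy numbers axes like TensorFlow does).
x = np.zeros((2, 3, 4, 5))

n = x.ndim              # number of dimensions
axes = list(range(n))   # valid axis values: 0, 1, ..., n-1
print(n, axes)          # 4 [0, 1, 2, 3]
```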

What is tensor shape?

Tensor shape gives the number of elements along each axis (dimension).

For example, if the shape of a tensor is [2, 3, 4], the tensor has axis = [0, 1, 2], with 2 elements on axis 0, 3 elements on axis 1 and 4 elements on axis 2.
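To see where a shape comes from, here is a minimal pure-Python sketch that treats a tensor as a nested list. The length at each nesting level is the number of elements on that axis. The helper `nested_shape` is hypothetical and is not a TensorFlow function. It also assumes the nested list is rectangular, meaning every sublist at the same level has the same length.

```python
# Hypothetical helper (not part of TensorFlow): infer the shape of a
# rectangular nested list by measuring the length at each nesting level.
def nested_shape(data):
    shape = []
    while isinstance(data, list):
        shape.append(len(data))          # number of elements on this axis
        data = data[0] if data else None # descend into the next axis
    return shape

# A 2 x 3 x 4 nested list filled with zeros.
t = [[[0 for _ in range(4)] for _ in range(3)] for _ in range(2)]
print(nested_shape(t))  # [2, 3, 4]
```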

The relation between tensor axis and shape

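Suppose a tensor has shape = [x, y, z, …]. Its axes are axis = [0, 1, 2, …], and shape[i] is the number of elements on axis i: x elements on axis 0, y elements on axis 1, z elements on axis 2, and so on.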

[Figure: tensor axis and shape relation]
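Here is a small NumPy sketch of the same relation, using an arbitrary [2, 3, 4] array. It reads shape[i] for each axis i. It then shows that fixing one index on axis 0 leaves a sub-tensor whose shape is the remaining part of the shape.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)   # shape = [2, 3, 4]

# shape[i] is the number of elements along axis i.
sizes = {axis: x.shape[axis] for axis in range(x.ndim)}
print(sizes)        # {0: 2, 1: 3, 2: 4}

# Picking one of the 2 elements on axis 0 leaves a sub-tensor of shape shape[1:].
print(x[0].shape)   # (3, 4)
```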

Once we understand the relation between axis and shape, we can understand the axis parameter in many TensorFlow functions.

For example, tf.reduce_sum() returns different results depending on which axis it sums over.
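Which axis you sum over also determines the shape of the result. By default (keepdims=False), tf.reduce_sum() removes the reduced axis from the shape. The sketch below shows this with NumPy's `np.sum`, which behaves the same way, on an arbitrary [2, 3, 4] array of ones.

```python
import numpy as np

x = np.ones((2, 3, 4))

result_shapes = {}
for axis in range(x.ndim):
    # Summing over one axis drops that axis from the shape.
    result_shapes[axis] = np.sum(x, axis=axis).shape

print(result_shapes)  # {0: (3, 4), 1: (2, 4), 2: (2, 3)}
```

This matches the 3×4 example that follows: summing over axis 0 leaves 4 values, and summing over axis 1 leaves 3 values.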

Here is an example.

Create a tensor with shape [3, 4]

import tensorflow as tf
import numpy as np

# w has shape [3, 4]: 3 rows on axis 0, 4 columns on axis 1
w = tf.Variable(np.array([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]]), dtype = tf.float32)

Sum this tensor along different axes

Sum along axis = 0

sum_1 = tf.reduce_sum(w, axis = 0)

The result will be:

array([ 15.,  18.,  21.,  24.], dtype=float32)

Why? Because shape = [3, 4], there are 3 elements on axis 0 (the 3 rows), so tf.reduce_sum() adds these 3 rows element by element:

[1, 2, 3, 4] + [5, 6, 7, 8] + [9, 10, 11, 12] = [15, 18, 21, 24]
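The same axis-0 reduction can be written as an explicit loop in plain Python. The outer loop runs once for each of the 3 elements on axis 0, and each row is added to a running total one position at a time.

```python
rows = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

# Reduce over axis 0: add the 3 rows together, element by element.
sum_axis0 = [0] * len(rows[0])
for row in rows:                    # one iteration per element on axis 0
    for j, value in enumerate(row):
        sum_axis0[j] += value

print(sum_axis0)  # [15, 18, 21, 24]
```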

Sum along axis = 1

sum_2 = tf.reduce_sum(w, axis = 1)

The result will be:

array([ 10.,  26.,  42.], dtype=float32)

Because there are 4 elements on axis 1 (the 4 columns), tf.reduce_sum() adds the 4 elements within each row:

sum([1, 2, 3, 4]) = 1 + 2 + 3 + 4 = 10

sum([5, 6, 7, 8]) = 5 + 6 + 7 + 8 = 26

sum([9, 10, 11, 12]) = 9 + 10 + 11 + 12 = 42
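As a cross-check, the sketch below computes the axis-1 sums by hand and then with NumPy's `np.sum`, whose `axis` argument follows the same convention as tf.reduce_sum(). The array is the same data as `w` above, but it is a NumPy array rather than a tf.Variable, so it runs without TensorFlow.

```python
import numpy as np

w = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)

# Manual reduction over axis 1: add the 4 elements inside each row.
manual = [sum(row) for row in w.tolist()]

# The same reduction using the axis argument.
numpy_result = np.sum(w, axis=1)

print(manual)        # [10.0, 26.0, 42.0]
print(numpy_result)  # [10. 26. 42.]
```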
