Understand TensorFlow tf.stack(): Packing a List of Tensors Along The Axis Dimension – TensorFlow Tutorial

June 13, 2020

The TensorFlow tf.stack() function packs a list of tensors into a single new tensor, which is useful when you need to combine several tensors of the same shape. In this tutorial, we show how to use this function with some examples.

Syntax

tf.stack(
    values,
    axis=0,
    name='stack'
)

tf.stack() packs a list of tensors along the given axis into a new tensor.

Parameters

values: a list of tensors to stack. All tensors must have the same shape and dtype; let R be their rank.

axis: the dimension along which to stack values. It determines where the new dimension appears in the output and must be in the range [-(R+1), R+1).

name: an optional name for the operation.

When using the tf.stack() function, there are two points to note:

1. If each tensor in values has rank R, the returned tensor has rank R+1.

2. If len(values) = N and each tensor in values has shape (A, B, C), the shape of the returned tensor is:

Axis         Shape of returned tensor
axis = 0     (N, A, B, C)
axis = 1     (A, N, B, C)
axis = 2     (A, B, N, C)
axis = 3     (A, B, C, N)
axis = -1    same as axis = 3
axis = -2    same as axis = 2
axis = -3    same as axis = 1
axis = -4    same as axis = 0
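The shape rule in the table can be sketched in a few lines of plain Python. This is only an illustration of the rule, not TensorFlow code: the helper name stacked_shape and the sample shape (5, 6, 7) are assumptions made for this example.

```python
def stacked_shape(shape, n, axis=0):
    """Return the shape produced by stacking n tensors that all have `shape`.

    Hypothetical helper that mirrors the shape rule in plain Python;
    it does not call TensorFlow.
    """
    rank = len(shape)
    # Valid axes lie in [-(R+1), R+1), where R is the rank of each input.
    if not -(rank + 1) <= axis < rank + 1:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    # A negative axis counts from the end of the output, which has rank R+1.
    if axis < 0:
        axis += rank + 1
    # The new dimension of size n is inserted at position `axis`.
    return tuple(shape[:axis]) + (n,) + tuple(shape[axis:])


# Print the output shape for every valid axis when N = 2 and (A, B, C) = (5, 6, 7).
for axis in range(-4, 4):
    print(axis, stacked_shape((5, 6, 7), n=2, axis=axis))
```

Running the loop shows that each negative axis gives the same shape as its positive counterpart, for example axis = -1 and axis = 3.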

The following examples show how to use the tf.stack() function.

Stack two tensors to create a new tensor

We create two 3 * 4 tensors first.

import tensorflow as tf
import numpy as np

# x1 and x2 both have shape (3, 4)
x1 = tf.Variable(np.array([[2, 2, 3, 4], [1, 5, 3, 2], [6, 7, 2, 1]], dtype=np.float32), name='x1')
x2 = tf.Variable(np.array([[3, 3, 3, 5], [6, 3, 2, 8], [4, 1, 6, 8]], dtype=np.float32), name='x2')

The length of values is 2, so N = 2. Both x1 and x2 have shape (3, 4), so A = 3 and B = 4.

Stack x1 and x2 along axis = 0

If axis = 0, the shape of the output tensor will be (N, A, B) = (2, 3, 4).

# the shape of output is (2, 3, 4)
x3 = tf.stack([x1, x2], axis = 0)

If we output x3, we get:

[[[2. 2. 3. 4.]
  [1. 5. 3. 2.]
  [6. 7. 2. 1.]]

 [[3. 3. 3. 5.]
  [6. 3. 2. 8.]
  [4. 1. 6. 8.]]]

This result is easy to understand: x1 and x2 are placed one after the other along a new leading axis, so x3[0] is x1 and x3[1] is x2.

Stack x1 and x2 along axis = 1

The shape of the output tensor x3 will be (A, N, B) = (3, 2, 4).

# the shape of output is (3, 2, 4)
x3 = tf.stack([x1, x2], axis = 1)

The result will be:

[[[2. 2. 3. 4.]
  [3. 3. 3. 5.]]

 [[1. 5. 3. 2.]
  [6. 3. 2. 8.]]

 [[6. 7. 2. 1.]
  [4. 1. 6. 8.]]]

Stack x1 and x2 along axis = 2

The shape of x3 will be (A, B, N) = (3, 4, 2).

x3 = tf.stack([x1, x2], axis = 2)

The stacked result is:

[[[2. 3.]
  [2. 3.]
  [3. 3.]
  [4. 5.]]

 [[1. 6.]
  [5. 3.]
  [3. 2.]
  [2. 8.]]

 [[6. 4.]
  [7. 1.]
  [2. 6.]
  [1. 8.]]]
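To see why the three results differ, it may help to rebuild them with plain nested Python lists. The sketch below is not TensorFlow code: the lists stand in for the tensors x1 and x2, and the names stack_axis0, stack_axis1 and stack_axis2 are made up for this illustration. In each case, the element values[n][i][j] simply moves to a different position in the output.

```python
# Plain nested lists standing in for the tensors x1 and x2 (no TensorFlow needed).
x1 = [[2, 2, 3, 4], [1, 5, 3, 2], [6, 7, 2, 1]]
x2 = [[3, 3, 3, 5], [6, 3, 2, 8], [4, 1, 6, 8]]

# axis = 0: the inputs become the outermost entries,
# so out[n][i][j] = values[n][i][j]
stack_axis0 = [x1, x2]

# axis = 1: rows with the same index are paired,
# so out[i][n][j] = values[n][i][j]
stack_axis1 = [list(rows) for rows in zip(x1, x2)]

# axis = 2: elements at the same position are paired,
# so out[i][j][n] = values[n][i][j]
stack_axis2 = [[list(pair) for pair in zip(r1, r2)] for r1, r2 in zip(x1, x2)]

print(stack_axis1[0])  # [[2, 2, 3, 4], [3, 3, 3, 5]]
print(stack_axis2[0])  # [[2, 3], [2, 3], [3, 3], [4, 5]]
```

The printed slices match the first blocks of the axis = 1 and axis = 2 results shown earlier, apart from the integer formatting.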
