Understand tf.expand_dims() with Examples – TensorFlow Tutorial

By | March 1, 2021

TensorFlow tf.expand_dims() inserts a new dimension of size 1 into the shape of a tensor. In this tutorial, we will use some examples to show you how to use this function.

tf.expand_dims()

tf.expand_dims() is defined as:

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

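It returns a tensor with the same data as input, but with a dimension of size 1 inserted into its shape at index axis. The dim argument is a deprecated alias for axis; in TensorFlow 2.x the signature is simply tf.expand_dims(input, axis, name=None).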
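The shape rule can be written out in plain Python to make it concrete. The helper expand_dims_shape below is hypothetical (it is not part of TensorFlow); it is a minimal sketch that assumes the axis range documented for tf.expand_dims, where axis must lie in [-(D + 1), D] for an input of rank D and negative values count from the end of the output shape.

```python
def expand_dims_shape(shape, axis):
    """Return the shape produced by inserting a size-1 dimension at `axis`.

    Hypothetical helper mirroring the shape rule of tf.expand_dims:
    for an input of rank D, `axis` must lie in [-(D + 1), D].
    """
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    # A negative axis counts from the end of the output shape, which has rank D + 1.
    if axis < 0:
        axis += rank + 1
    return list(shape[:axis]) + [1] + list(shape[axis:])


print(expand_dims_shape([2], 0))          # [1, 2]
print(expand_dims_shape([2], 1))          # [2, 1]
print(expand_dims_shape([2], -1))         # [2, 1]
print(expand_dims_shape([2, 3, 5], 2))    # [2, 3, 1, 5]
```

Note that the output has one more dimension than the input, which is why axis may equal D (append at the end) but not D + 1.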

Here are some examples:

import tensorflow as tf

t = tf.constant([1, 2])  # 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 0))  # [1, 2]

A dimension of size 1 is inserted at axis = 0, so the new shape has 1 at index 0 and the original dimension of size 2 moves to index 1.

# 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 1))  # [2, 1]

Here the dimension of size 1 is inserted at axis = 1, after the existing dimension, so the new shape is [2, 1].
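To see where the values end up, and not only the shape, here is a rough pure-Python analogy using nested lists. expand_dims_list is a hypothetical helper that handles only non-negative axes; it illustrates the idea and is not how TensorFlow implements tf.expand_dims.

```python
def expand_dims_list(x, axis):
    """Wrap every element found at nesting depth `axis` in a one-element list.

    Hypothetical illustration only (non-negative axis, no bounds checking).
    """
    if axis == 0:
        return [x]
    return [expand_dims_list(item, axis - 1) for item in x]


t = [10, 20]                      # stands in for a tensor of shape [2]
print(expand_dims_list(t, 0))     # [[10, 20]]    -> shape [1, 2]
print(expand_dims_list(t, 1))     # [[10], [20]]  -> shape [2, 1]
```

With axis = 0 the whole tensor becomes the single row of a new outer dimension; with axis = 1 each element gets its own one-element row.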

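More examples. A negative axis counts from the end, so axis = -1 adds the new dimension after the last existing one:

# 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, -1))  # [2, 1]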

t2 = tf.zeros([2, 3, 5])  # 't2' is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0))  # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2))  # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3))  # [2, 3, 5, 1]
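Assuming TensorFlow 2.x with eager execution, the same insertion can also be written with tf.newaxis indexing or an explicit tf.reshape. The sketch below compares the three on a zero-filled t2 of shape [2, 3, 5]; the zero values are only an assumption for illustration.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x running eagerly, so shapes can be inspected directly.
t2 = tf.zeros([2, 3, 5])

a = tf.expand_dims(t2, 2)          # insert a size-1 axis at index 2
b = t2[:, :, tf.newaxis, :]        # same result via indexing with tf.newaxis
c = tf.reshape(t2, [2, 3, 1, 5])   # same result via an explicit reshape

print(a.shape.as_list())  # [2, 3, 1, 5]
print(b.shape.as_list())  # [2, 3, 1, 5]
print(c.shape.as_list())  # [2, 3, 1, 5]
```

tf.reshape needs the full target shape written out, while tf.expand_dims only needs the position of the new axis, which is why it is usually the more convenient choice.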
