Understand tf.concat(): Concatenate Tensors for Beginners – TensorFlow Tutorial

By | December 6, 2019

The TensorFlow tf.concat() function is widely used in deep learning, especially when you build a BiLSTM with TensorFlow. In this tutorial, we will walk through some examples to help you understand it.

Syntax

tf.concat(
    values,
    axis,
    name='concat'
)

Concatenates tensors along one dimension.

Parameters explained

values: a list of tensors to concatenate

axis: the dimension along which to concatenate the tensors

name: an optional name for the operation (default 'concat')
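The shape rule behind axis can be checked with a small sketch. It assumes TensorFlow 2.x with eager execution, and the tensors a and b are hypothetical examples. Sizes along axis add up, every other dimension must match, and a negative axis counts from the end.

```python
import tensorflow as tf

# Two hypothetical tensors that agree in every dimension except axis 1.
a = tf.zeros([2, 3])
b = tf.ones([2, 5])

# Along axis 1 the sizes add up (3 + 5 = 8); axis 0 (size 2) must match.
c = tf.concat([a, b], axis=1)
print(c.shape)  # (2, 8)

# A negative axis counts from the end, so for 2-D tensors axis=-1 means axis=1.
d = tf.concat([a, b], axis=-1)
print(d.shape)  # (2, 8)
```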

We will use some examples to show you how to use this function.

Create two 2 * 4 tensors

import tensorflow as tf
import numpy as np

x1 = tf.Variable(np.array([[1, 2, 3, 4],[5, 6, 7, 8]]), dtype = tf.float32)
x2 = tf.Variable(np.array([[8, 7, 6, 5],[4, 3, 2, 1]]), dtype = tf.float32)

In the code above, we create two tensors, each with shape 2 * 4.
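If you are on TensorFlow 2.x, where eager execution is on by default, you can confirm the shape and data type of these tensors right away. This is a minimal sketch under that assumption.

```python
import numpy as np
import tensorflow as tf

# Same tensors as above; dtype=tf.float32 converts the integer NumPy arrays.
x1 = tf.Variable(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), dtype=tf.float32)
x2 = tf.Variable(np.array([[8, 7, 6, 5], [4, 3, 2, 1]]), dtype=tf.float32)

# Both tensors are 2 x 4 and float32, so they can be concatenated.
print(x1.shape, x1.dtype)
print(x2.shape, x2.dtype)
```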

Concatenate tensors on axis = 0

x = tf.concat([x1, x2], axis = 0)

Note that x1 and x2 must be passed together as a single list, [x1, x2], rather than as separate arguments.
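The list is not limited to two tensors. Here is a short sketch, assuming TensorFlow 2.x eager mode and re-creating x1 and x2 as constants with the same values:

```python
import tensorflow as tf

x1 = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=tf.float32)
x2 = tf.constant([[8, 7, 6, 5], [4, 3, 2, 1]], dtype=tf.float32)

# Any number of tensors can go in the list: three 2 x 4 tensors
# joined along axis 0 give a 6 x 4 tensor.
y = tf.concat([x1, x2, x1], axis=0)
print(y.shape)  # (6, 4)
```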

The concatenated tensor has shape 4 * 4:

[array([[1., 2., 3., 4.],
       [5., 6., 7., 8.],
       [8., 7., 6., 5.],
       [4., 3., 2., 1.]], dtype=float32)]
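In TensorFlow 2.x you can read the result with .numpy() instead of evaluating it in a session. The sketch below (assuming TF 2.x) also compares it with NumPy's np.concatenate, which uses the same axis convention.

```python
import numpy as np
import tensorflow as tf

a1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float32)
a2 = np.array([[8, 7, 6, 5], [4, 3, 2, 1]], dtype=np.float32)

# Rows of the first tensor followed by rows of the second: shape (4, 4).
x = tf.concat([tf.constant(a1), tf.constant(a2)], axis=0)
print(x.numpy())

# np.concatenate along the same axis gives the same array.
expected = np.concatenate([a1, a2], axis=0)
print(np.array_equal(x.numpy(), expected))  # True
```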

Concatenate tensors on axis = 1

x = tf.concat([x1, x2], axis = 1)

The result is a 2 * 8 tensor:

[array([[1., 2., 3., 4., 8., 7., 6., 5.],
       [5., 6., 7., 8., 4., 3., 2., 1.]], dtype=float32)]
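For 2-D tensors, axis = 1 is also the last axis. Concatenating along the last axis is the typical BiLSTM case mentioned at the start, where the forward and backward outputs are joined feature-wise. This is only a sketch of the shape arithmetic, assuming TensorFlow 2.x: the sizes (batch 2, 5 time steps, 3 units per direction) are hypothetical, and random tensors stand in for real LSTM outputs.

```python
import tensorflow as tf

batch, time_steps, units = 2, 5, 3  # hypothetical sizes

# Stand-ins for the outputs of a forward and a backward LSTM.
forward_out = tf.random.normal([batch, time_steps, units])
backward_out = tf.random.normal([batch, time_steps, units])

# axis=-1 joins the feature dimension: 3 + 3 = 6 features per time step.
bi_out = tf.concat([forward_out, backward_out], axis=-1)
print(bi_out.shape)  # (2, 5, 6)
```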

However, keep the following in mind:

All tensors in values must have the same rank, and their sizes must match in every dimension except axis.

They must also have the same data type; otherwise, you will get an error like: TypeError: Tensors in list passed to ‘values’ of ‘ConcatV2’ Op have types [int32, float32] that don’t all match.

To fix this error, see this tutorial:

Fix Tensors in list passed to ‘values’ of ‘ConcatV2’ Op have types [int32, float32] that don’t all match Error
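One common fix is to cast every tensor to a shared data type with tf.cast before concatenating. A minimal sketch, assuming TensorFlow 2.x; the tensors i and f are hypothetical:

```python
import tensorflow as tf

i = tf.constant([[1, 2]], dtype=tf.int32)        # int32 tensor
f = tf.constant([[0.5, 1.5]], dtype=tf.float32)  # float32 tensor

# tf.concat([i, f], axis=0) would fail because the dtypes differ.
# Casting i to float32 first makes the dtypes match.
out = tf.concat([tf.cast(i, tf.float32), f], axis=0)
print(out.numpy())
```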
