The TensorFlow tf.concat() function is widely used in deep learning, especially when you are building a BiLSTM with TensorFlow. In this tutorial, we will work through some examples to help you understand it.
Syntax
tf.concat( values, axis, name='concat' )
Concatenates tensors along one dimension.
Parameters explained
values: a list of tensors you plan to concatenate
axis: the dimension along which the tensors are concatenated
We will use some examples to show you how to use this function.
Create two 2 * 4 tensors
import tensorflow as tf
import numpy as np

x1 = tf.Variable(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), dtype = tf.float32)
x2 = tf.Variable(np.array([[8, 7, 6, 5], [4, 3, 2, 1]]), dtype = tf.float32)
In the code above, we have created two tensors, each of shape 2 * 4.
Concatenate tensors on axis = 0
x = tf.concat([x1, x2], axis = 0)
Note that x1 and x2 must be placed in a list before they are concatenated.
The concatenated tensor is:
[array([[1., 2., 3., 4.],
        [5., 6., 7., 8.],
        [8., 7., 6., 5.],
        [4., 3., 2., 1.]], dtype=float32)]
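With axis = 0 the rows of the inputs are stacked, so the row counts add up while the column count stays the same. The following is a minimal sketch of that shape rule. The tensors a and b are made up for illustration and are not part of the original example, and tf.constant is used instead of tf.Variable only for brevity.

```python
import tensorflow as tf

# Hypothetical tensors for illustration: 2 rows and 1 row, both with 4 columns.
a = tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.]])
b = tf.constant([[9., 10., 11., 12.]])

stacked = tf.concat([a, b], axis=0)

# Static shape: the rows add up (2 + 1), the columns are unchanged (4).
stacked_shape = stacked.shape.as_list()
print(stacked_shape)
```

Only the dimension named by axis may differ between the inputs. The other dimensions have to match.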
Concatenate tensors on axis = 1
x = tf.concat([x1, x2], axis = 1)
The result is:
[array([[1., 2., 3., 4., 8., 7., 6., 5.],
        [5., 6., 7., 8., 4., 3., 2., 1.]], dtype=float32)]
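For 2-D tensors, axis = 1 is the last axis, so it can also be written as axis = -1. The sketch below checks this. It assumes TensorFlow 2.x with eager execution, and the names a, b, by_index and by_negative are illustrative only.

```python
import tensorflow as tf

# Same values as x1 and x2 above, built as constants for brevity.
a = tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.]])
b = tf.constant([[8., 7., 6., 5.], [4., 3., 2., 1.]])

by_index = tf.concat([a, b], axis=1)
by_negative = tf.concat([a, b], axis=-1)  # -1 refers to the last axis

# Columns add up (4 + 4), rows are unchanged (2).
wide_shape = by_index.shape.as_list()

# Both calls should produce the same tensor (bool() here assumes eager mode).
same_values = bool(tf.reduce_all(tf.equal(by_index, by_negative)))
print(wide_shape, same_values)
```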
However, you should also make sure that:
All tensors in values have the same data type. Otherwise, you will get an error: TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, float32] that don't all match.
To fix this error, you can view this tutorial.
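One common way to avoid the data type mismatch, sketched here as an assumption rather than as a summary of that tutorial, is to convert the tensors to a shared dtype with tf.cast() before concatenating. The tensors ints and floats are hypothetical.

```python
import tensorflow as tf

ints = tf.constant([[1, 2], [3, 4]])            # dtype int32
floats = tf.constant([[0.5, 1.5], [2.5, 3.5]])  # dtype float32

# Cast the integer tensor so that both inputs are float32.
ints_as_float = tf.cast(ints, tf.float32)
merged = tf.concat([ints_as_float, floats], axis=0)

merged_dtype = merged.dtype
merged_shape = merged.shape.as_list()
print(merged_dtype, merged_shape)
```

Casting to the wider type (float32 here) keeps the fractional values. Casting the floats to int32 instead would drop them.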