TensorFlow's tf.tensordot() is a powerful function for multiplying tensors. It computes a tensor contraction (a sum of products over selected axes) and works with tensors of different ranks. In this tutorial, we will use some examples to show you how to use this function.
tf.tensordot()
It is defined as:
tf.tensordot(a, b, axes, name=None)
Here we should notice:
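- axes can be an integer, or a list or tuple of length 2.
- If axes is an integer N, tf.tensordot() sums over the last N axes of a and the first N axes of b.
- If axes is a list or tuple, it is divided into [a_axes, b_axes], which means axes = [a_axes, b_axes].
- a_axes selects the axes of tensor a to sum over, and b_axes selects the axes of tensor b to sum over. The i-th axis in a_axes is paired with the i-th axis in b_axes.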
For example:
```python
y = tf.tensordot(a, b, 1)
y = tf.tensordot(a, b, [1, 1])
y = tf.tensordot(a, b, [(1, 2), (2, 10)])
```
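If axes is an integer, it should be greater than 0 in the TensorFlow version used in this tutorial. Newer versions also accept axes = 0, which computes the outer product of a and b.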
How to multiply tensors in tf.tensordot()?
We will use some examples to explain.
If axes is an integer
Look at this example:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

a = tf.ones(shape=[5, 4, 2, 3])
b = tf.ones(shape=[3, 2, 6])
c = tf.tensordot(a, b, axes=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # print(sess.run(c))
    print(sess.run(tf.shape(c)))  # [5 4 2 2 6]
```
Here axes = 1 is an integer.
How do we get the shape of the result c?
Look at the tf.tensordot() source code:
https://github.com/tensorflow/tensorflow/blob/23c218785eac5bfe737eec4f8081fd0ef8e0684d/tensorflow/python/ops/math_ops.py#L2899
The shape of tensor c is computed from a_axes and b_axes. At the end of tf.tensordot(), the matrix product ab_matmul is reshaped to a_free_dims followed by b_free_dims:
```python
product = array_ops.reshape(
    ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
```
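How do we compute a_free_dims and b_free_dims?
First, the integer axes = 1 is converted to a_axes and b_axes. Look at the source code:
```python
a_axes, b_axes = _tensordot_axes(a, axes)
```
When axes is an integer N, _tensordot_axes() returns a_axes = range(rank_a - N, rank_a), the last N axes of a, and b_axes = range(N), the first N axes of b.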
In this example:
The shape of a is [5, 4, 2, 3]
rank_a = 4
a_axes = range(4-1,4) = [3]
b_axes = range(1) = [0]
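To check this arithmetic without TensorFlow, here is a minimal plain-Python sketch of the integer case. axes_from_int is a hypothetical helper written for illustration, not a TensorFlow function; it assumes the rank of a is known and only reproduces the two index ranges described above.

```python
def axes_from_int(rank_a, n):
    """Hypothetical helper: contract the last n axes of a with the first n axes of b."""
    a_axes = list(range(rank_a - n, rank_a))  # last n axes of a
    b_axes = list(range(n))                   # first n axes of b
    return a_axes, b_axes


# a has shape [5, 4, 2, 3], so rank_a = 4; axes = 1
print(axes_from_int(4, 1))  # ([3], [0])
```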
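To get a_free_dims and b_free_dims, the source code removes the contracted axes from the list of all axes (with setdiff1d). The remaining axes are the free axes, and their sizes are the free dims.
In this example: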
The shape of a is [5, 4, 2, 3], and the shape of b is [3, 2, 6]
rank_a = 4, rank_b = 3
a_axes = [3], b_axes = [0]
a_free = setdiff1d([0, 1, 2, 3],[3]) = [0, 1, 2]
a_free_dims = [5, 4, 2]
b_free = setdiff1d([0, 1, 2], [0]) = [1,2]
b_free_dims = [2, 6]
The shape of c = concat([5, 4, 2], [2, 6]) = [5, 4, 2, 2, 6]
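The free-axes step can be sketched the same way. In this hypothetical free_dims helper (not part of TensorFlow), a list comprehension plays the role of setdiff1d(); it assumes non-negative axis indices.

```python
def free_dims(shape, axes):
    """Sizes of the axes of `shape` that are NOT contracted, in their original order."""
    free = [i for i in range(len(shape)) if i not in axes]  # like setdiff1d
    return [shape[i] for i in free]


shape_a, shape_b = [5, 4, 2, 3], [3, 2, 6]
a_axes, b_axes = [3], [0]  # from axes = 1

a_free_dims = free_dims(shape_a, a_axes)  # [5, 4, 2]
b_free_dims = free_dims(shape_b, b_axes)  # [2, 6]
shape_c = a_free_dims + b_free_dims       # list concatenation, like concat(..., 0)
print(shape_c)  # [5, 4, 2, 2, 6]
```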
If axes = [a_axes, b_axes]
For example:
axes = [1, 1]
axes will be converted to:
axes = [[1], [1]]
Here is the source code:
```python
a_axes = axes[0]
b_axes = axes[1]
if isinstance(a_axes, compat.integral_types) and \
    isinstance(b_axes, compat.integral_types):
  a_axes = [a_axes]
  b_axes = [b_axes]
```
Now a_axes and b_axes are lists.
We must make sure that len(a_axes) == len(b_axes); otherwise _tensordot_axes() raises a ValueError.
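The list branch, including the length check, can be sketched in plain Python. axes_from_pair is a hypothetical helper that mirrors the logic described above; it uses Python's int in place of compat.integral_types and does not handle tensor-valued axes.

```python
def axes_from_pair(axes):
    """Hypothetical helper: normalize axes = [a_axes, b_axes] into two lists."""
    if len(axes) != 2:
        raise ValueError("axes must be an integer or have length 2")
    a_axes, b_axes = axes[0], axes[1]
    # A pair of plain integers such as [1, 1] becomes [[1], [1]].
    if isinstance(a_axes, int) and isinstance(b_axes, int):
        a_axes, b_axes = [a_axes], [b_axes]
    a_axes, b_axes = list(a_axes), list(b_axes)
    # Both tensors must contract the same number of axes.
    if len(a_axes) != len(b_axes):
        raise ValueError("different number of contraction axes: %d != %d"
                         % (len(a_axes), len(b_axes)))
    return a_axes, b_axes


print(axes_from_pair([1, 1]))             # ([1], [1])
print(axes_from_pair([(1, 2), (2, 10)]))  # ([1, 2], [2, 10])
```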
How do we compute tensor c when axes is a list?
As in the integer case, the shape of tensor c is:
a_free_dims + b_free_dims
Computing c takes three main steps. We use a of shape [5, 4, 2, 3], b of shape [3, 2, 6] and axes = [1, 1] as an example.
Step 1:
tensor a shape: [5, 4, 2, 3]
a_axes = [1]
a_free = [i for i in xrange(4) if i not in [1]] = [0, 2, 3]
a_free_dims = [5, 2, 3]
In the same way, b has shape [3, 2, 6] and b_axes = [1], so b_free = [0, 2] and b_free_dims = [3, 6].
Step 2:
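Tensor a is transposed so that its free axes come first, then reshaped to (5*2*3) * 4 = 30 * 4.
Tensor b is transposed so that its contracted axes come first, then reshaped to 2 * (3*6) = 2 * 18.
The two matrices are then multiplied with MatMul. Since 4 ≠ 2, the multiplication fails.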
For example:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

a = tf.ones(shape=[5, 4, 2, 3])
r = tf.rank(a)
b = tf.ones(shape=[3, 2, 6])
c = tf.tensordot(a, b, axes=[1, 1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(r))
    print(sess.run(tf.shape(c)))
```
If you run this code, you will get an error:
ValueError: Dimensions must be equal, but are 4 and 2 for 'Tensordot/MatMul' (op: 'MatMul') with input shapes: [30,4], [2,18].
If we change b to:
```python
b = tf.ones(shape=[3, 4, 6])
```
the inner dimensions match (4 = 4), and the matrix product has the shape 30 * 18.
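A plain-Python sketch of Step 2 shows where the 30 * 4 and 2 * 18 shapes come from. matmul_shapes is a hypothetical helper: it computes only the 2-D shapes (not the transposed data), assumes non-negative axis indices, and uses math.prod (Python 3.8+).

```python
from math import prod


def matmul_shapes(shape_a, a_axes, shape_b, b_axes):
    """Hypothetical helper: 2-D shapes that a and b are reshaped to before MatMul."""
    a_free = [i for i in range(len(shape_a)) if i not in a_axes]
    b_free = [i for i in range(len(shape_b)) if i not in b_axes]
    # a: free axes first, contracted axes last -> [prod(free), prod(contracted)]
    a_2d = [prod(shape_a[i] for i in a_free), prod(shape_a[i] for i in a_axes)]
    # b: contracted axes first, free axes last -> [prod(contracted), prod(free)]
    b_2d = [prod(shape_b[i] for i in b_axes), prod(shape_b[i] for i in b_free)]
    return a_2d, b_2d


a_2d, b_2d = matmul_shapes([5, 4, 2, 3], [1], [3, 2, 6], [1])
print(a_2d, b_2d)          # [30, 4] [2, 18]
print(a_2d[1] == b_2d[0])  # False: 4 != 2, so MatMul fails

a_2d, b_2d = matmul_shapes([5, 4, 2, 3], [1], [3, 4, 6], [1])
print(a_2d, b_2d)          # [30, 4] [4, 18]
```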
Step 3:
The matrix product is reshaped to get tensor c.
The shape of tensor c is:
a_free_dims + b_free_dims
In this example, it is:
[5, 2, 3] + [3, 6] = [5, 2, 3, 3, 6]
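Putting Steps 1-3 together, a hypothetical tensordot_shape helper can predict the result shape, or the failure, without running TensorFlow. Like the Step 2 description, it compares only the products of the contracted sizes; a stricter check would compare each pair of contracted axes. It assumes non-negative axis indices and Python 3.8+ for math.prod.

```python
from math import prod


def tensordot_shape(shape_a, a_axes, shape_b, b_axes):
    """Hypothetical helper: predict the shape of tf.tensordot(a, b, [a_axes, b_axes])."""
    # Step 1: sizes of the free (non-contracted) axes
    a_free_dims = [d for i, d in enumerate(shape_a) if i not in a_axes]
    b_free_dims = [d for i, d in enumerate(shape_b) if i not in b_axes]
    # Step 2: the inner sizes of the 2-D matrix product must agree
    inner_a = prod(shape_a[i] for i in a_axes)
    inner_b = prod(shape_b[i] for i in b_axes)
    if inner_a != inner_b:
        raise ValueError("inner sizes differ: %d != %d" % (inner_a, inner_b))
    # Step 3: the [prod(a_free_dims), prod(b_free_dims)] product is reshaped
    return a_free_dims + b_free_dims


print(tensordot_shape([5, 4, 2, 3], [1], [3, 4, 6], [1]))  # [5, 2, 3, 3, 6]
```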
Let's look at more examples:
Example 1:
```python
a = tf.ones(shape=[5, 4, 2, 3])
b = tf.ones(shape=[3, 4, 6])
c = tf.tensordot(a, b, axes=[[0, 2], [1, 2]])
```
Here a_axes = [0, 2], so the free axes of a are [1, 3], and a is reshaped to (4*3) * (5*2) = 12 * 10.
b_axes = [1, 2], so the only free axis of b is 0, and b is reshaped to (4*6) * 3 = 24 * 3.
Since 10 ≠ 24, it reports an error:
ValueError: Dimensions must be equal, but are 10 and 24 for 'Tensordot/MatMul' (op: 'MatMul') with input shapes: [12,10], [24,3].
Example 2:
```python
a = tf.ones(shape=[5, 4, 2, 3])
b = tf.ones(shape=[3, 5, 2])
c = tf.tensordot(a, b, axes=[[0, 2], [1, 2]])
```
Here a is reshaped to 12 * 10 and b is reshaped to (5*2) * 3 = 10 * 3, so the matrix product has the shape 12 * 3.
a_free_dims is [4, 3] and b_free_dims is [3].
Tensor c will be reshaped to [4, 3, 3].
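To see the values as well as the shapes, here is a minimal pure-Python sketch of the same transpose, reshape, MatMul and reshape pipeline. It is illustrative only, not TensorFlow's implementation: tensors are represented as a shape plus a flat row-major list (an assumption of this sketch), and row_major_strides, to_matrix, matmul and tensordot are hypothetical helpers. Unlike the reshape trick on its own, it also checks that each pair of contracted axes has the same size, so Example 1 raises a ValueError here as well.

```python
from itertools import product


def row_major_strides(shape):
    """Stride of each axis in a flat row-major list (the last axis varies fastest)."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return strides[::-1]


def to_matrix(shape, data, row_axes, col_axes):
    """Transpose + reshape: rows enumerate row_axes, columns enumerate col_axes."""
    strides = row_major_strides(shape)

    def offset(axes, index):
        return sum(strides[ax] * i for ax, i in zip(axes, index))

    return [
        [data[offset(row_axes, r) + offset(col_axes, c)]
         for c in product(*(range(shape[ax]) for ax in col_axes))]
        for r in product(*(range(shape[ax]) for ax in row_axes))
    ]


def matmul(x, y):
    """Plain 2-D matrix product; the inner dimensions must agree."""
    if len(x[0]) != len(y):
        raise ValueError("inner dimensions differ: %d != %d" % (len(x[0]), len(y)))
    return [[sum(x_ik * y[k][j] for k, x_ik in enumerate(row))
             for j in range(len(y[0]))] for row in x]


def tensordot(shape_a, data_a, shape_b, data_b, a_axes, b_axes):
    """Hypothetical pure-Python tensordot: returns (shape_c, flat row-major data_c)."""
    # Stricter than comparing products: each contracted pair must have equal size.
    for i, j in zip(a_axes, b_axes):
        if shape_a[i] != shape_b[j]:
            raise ValueError("axis %d of a (%d) != axis %d of b (%d)"
                             % (i, shape_a[i], j, shape_b[j]))
    # Step 1: free axes
    a_free = [i for i in range(len(shape_a)) if i not in a_axes]
    b_free = [i for i in range(len(shape_b)) if i not in b_axes]
    # Step 2: a -> [prod(a_free), prod(a_axes)], b -> [prod(b_axes), prod(b_free)], then MatMul
    ab = matmul(to_matrix(shape_a, data_a, a_free, a_axes),
                to_matrix(shape_b, data_b, b_axes, b_free))
    # Step 3: reshaping a row-major matrix only reinterprets its flat data
    shape_c = [shape_a[i] for i in a_free] + [shape_b[i] for i in b_free]
    return shape_c, [v for row in ab for v in row]


# Example 2 with all-ones tensors: every entry sums 5 * 2 = 10 products of ones
shape_c, data_c = tensordot([5, 4, 2, 3], [1.0] * 120, [3, 5, 2], [1.0] * 30, [0, 2], [1, 2])
print(shape_c)    # [4, 3, 3]
print(data_c[0])  # 10.0
```

With axes [1] and [0] on two 2 * 2 matrices, the same function reduces to an ordinary matrix multiplication.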
Note: all code in this tutorial has been tested with TensorFlow 1.10. The examples use the TensorFlow 1.x API (tf.Session() and tf.global_variables_initializer()), so if a snippet fails in your environment, for example a Colab notebook, check your TensorFlow version first.