TensorFlow's tf.map_fn() calls a function on each element of a tensor, where the elements are the slices of the tensor along axis 0. In this tutorial, we will use some simple examples to help you understand and use this function.
Syntax
tf.map_fn( fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, infer_shape=True, name=None )
Parameters explained
fn: the function to apply. It is applied to each element of elems along axis 0, and that element is its only argument.
elems: the tensor to unpack along axis 0; each resulting element is passed to fn.
back_prop: whether to support backpropagation through the loop. This matters when map_fn is used inside a deep learning model that you want to train.
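To make this contract concrete, here is a small NumPy model of what tf.map_fn computes. map_fn_reference is a hypothetical helper written only for this sketch; it is not part of TensorFlow, and TensorFlow does not implement map_fn this way. It only mirrors the behaviour described above: unpack elems along axis 0, apply fn to each element, and stack the results along a new axis 0. It ignores dtype, parallel_iterations, back_prop and the other options.

```python
import numpy as np


def map_fn_reference(fn, elems):
    """Hypothetical NumPy model of tf.map_fn's contract (not TensorFlow code).

    Unpacks `elems` along axis 0, calls `fn` on each element and stacks
    the results along a new axis 0.
    """
    elems = np.asarray(elems)
    results = [fn(elems[i]) for i in range(elems.shape[0])]
    return np.stack(results, axis=0)


# Same data as the x and z tensors created in the next section
x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=np.int32)
z = np.array([1, 2, 2, 1], dtype=np.int32)

seen_shapes = []


def integrate(ix):
    seen_shapes.append(ix.shape)  # each ix is one row of x
    return ix * z                 # element-wise product with z


xx = map_fn_reference(integrate, x)
print(seen_shapes)  # [(4,), (4,), (4,)]
print(xx.shape)     # (3, 4)
```

One difference to keep in mind: this model really calls fn once per row, whereas in TensorFlow 1.x graph mode fn is called once to build the loop body, which is why print(ix) in the example below prints a single symbolic tensor.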
We will use some simple examples to show how to use this function.
Create two tensors
import tensorflow as tf
import numpy as np

x = tf.Variable(np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]]), dtype=tf.int32)
z = tf.Variable(np.array([1, 2, 2, 1]), dtype=tf.int32)
We have created two tensors, x and z. We will multiply each element of x along axis 0 (that is, each row of x) by z, element-wise.
Use tf.map_fn
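def integrate(ix):
    # ix is one row of x; this prints a symbolic tensor, not the row's values
    print(ix)
    # element-wise product of the row and z
    x1 = tf.multiply(ix, z)
    return x1

# apply integrate to every row of x and stack the results
xx = tf.map_fn(integrate, x)
print(xx)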
Here we have created a function named integrate, which multiplies ix by z element-wise. ix is one element of x along axis 0, i.e. one row of x.
Print the result
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    print(sess.run(xx))
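Running this code prints the following. The first line comes from print(ix): in graph mode, map_fn calls integrate only once, while building the loop, so ix is a symbolic tensor rather than the values of a row. The second line comes from print(xx), and the last three lines are the evaluated result of sess.run(xx):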
Tensor("map/while/TensorArrayReadV3:0", shape=(4,), dtype=int32)
Tensor("map/TensorArrayStack/TensorArrayGatherV3:0", shape=(3, 4), dtype=int32)
[[1 4 4 1]
 [2 2 6 4]
 [4 6 2 1]]
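From the result, we can see:
1. The shape of x is (3, 4), but the shape of ix is (4,): fn receives one row at a time, with axis 0 removed. If the operation inside fn needs a 2-D input, you have to reshape ix there first, for example with tf.reshape or tf.expand_dims.
2. The shape of the result xx is (3, 4): map_fn stacks the outputs of fn along a new axis 0, so xx has one row for each row of x. In general, the result's shape is the number of elements along axis 0 followed by the shape of fn's output.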
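Both shape points can be seen in a small NumPy sketch, used here as a stand-in for TensorFlow. The function project and the weight matrix w are hypothetical, made up for this illustration: project needs a 2-D row for a matrix product, so it first expands ix from (4,) to (1, 4). In TensorFlow the same step would use tf.expand_dims (or tf.reshape) followed by tf.matmul inside fn. Stacking the (1, 2) outputs of the three rows gives a result of shape (3, 1, 2).

```python
import numpy as np

x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=np.int32)
# Hypothetical (4, 2) weight matrix, used only for this illustration
w = np.array([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=np.int32)


def project(ix):
    # ix arrives with shape (4,): axis 0 of x has been removed
    row = np.expand_dims(ix, axis=0)  # reshape to (1, 4) for a matrix product
    return row @ w                    # result has shape (1, 2)


# Stand-in for tf.map_fn(project, x): apply project to each row, stack on axis 0
xx = np.stack([project(ix) for ix in x], axis=0)
print(xx.shape)  # (3, 1, 2)
```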
The operation in our example is xx[i] = x[i] * z for each row i of x.
If you prefer a lambda, you can also write:
xx = tf.map_fn(lambda ix: tf.multiply(ix, z), x)
or
xx = tf.map_fn(lambda ix: integrate(ix), x)
To understand lambda, you can read:
Understand Python Lambda Function for Beginners – Python Tutorial
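The map_fn calls above all multiply x by z one row at a time. For this particular element-wise product a loop is not strictly needed: tf.multiply supports broadcasting, so tf.multiply(x, z) should produce the same (3, 4) result in a single call. The sketch below checks the same identity in NumPy, as a stand-in for TensorFlow; it does not call tf.map_fn itself.

```python
import numpy as np

x = np.array([[1, 2, 2, 1], [2, 1, 3, 4], [4, 3, 1, 1]], dtype=np.int32)
z = np.array([1, 2, 2, 1], dtype=np.int32)

# Row by row, as map_fn does it: xx[i] = x[i] * z
looped = np.stack([row * z for row in x], axis=0)

# One vectorized call: z, shape (4,), is broadcast against every row of x, shape (3, 4)
broadcast = x * z

print(np.array_equal(looped, broadcast))  # True
```

map_fn makes more sense when fn cannot be expressed as a single vectorized operation.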