tf.map_fn(): Processing Multiple Input and Output Tensors – TensorFlow Tutorial

By | January 25, 2022

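The TensorFlow tf.map_fn() method applies a function to each element of a tensor along axis 0 (that is, to each slice obtained by unstacking the tensor on its first dimension) and stacks the results into a new tensor. For a beginner introduction, see: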

Understand TensorFlow tf.map_fn(): A Beginner Guide – TensorFlow Tutorial
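To make "each element on axis 0" concrete, here is a minimal pure-NumPy sketch of that behaviour. The helper map_fn_sketch() is hypothetical and only illustrates the idea. It is not part of TensorFlow and ignores graph mode, dtypes and parallel execution.

```python
import numpy as np

def map_fn_sketch(fn, elems):
    """Hypothetical NumPy illustration of tf.map_fn's basic behaviour:
    take each slice of `elems` along axis 0, call `fn` on it,
    and stack the per-slice results into a new array."""
    results = [fn(elems[i]) for i in range(elems.shape[0])]
    return np.stack(results, axis=0)

v1 = np.array([[1, 2], [4, 5]], dtype=np.float32)

# fn receives one row at a time: [1, 2], then [4, 5]
row_sums = map_fn_sketch(lambda row: row.sum(), v1)
print(row_sums)  # [3. 9.]
```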

This function can also take multiple input tensors and return multiple output tensors. In this tutorial, we will discuss both cases.

How to process multiple input tensors in tf.map_fn()

To make tf.map_fn() process multiple input tensors, we should pack them into a tuple and pass that tuple as elems.

For example:

# Written for TensorFlow 1.x (graph mode with tf.Session)
import tensorflow as tf
import numpy as np
data = np.array([[1,2], [4,5]], dtype=np.float32)
v1 = tf.convert_to_tensor(data, dtype = tf.float32)

data = np.array([[2,2], [4,4]], dtype=np.float32)
v2 = tf.convert_to_tensor(data, dtype = tf.float32)

def xa(x):
    # x[0]: one row of v1
    # x[1]: the matching row of v2
    print(x[0])
    print(x[1])
    return x[0]+x[1]
# xa() returns one tensor, not a 2-tuple like (v1, v2), so dtype must be given
t = tf.map_fn(xa, (v1, v2), dtype = tf.float32)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run(t)
    print(_w)

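In this example, we pass two tensors, v1 and v2, into tf.map_fn(). Both have the same size on axis 0.

We pack them into the tuple (v1, v2), and tf.map_fn() calls xa() once for each index i along axis 0.

Inside xa(), x is a tuple holding one row from each input: x[0] is the current row of v1 and x[1] is the matching row of v2. In other words, (v1, v2) maps onto (x[0], x[1]).

Because xa() returns a single tensor while elems is a tuple of two tensors, we must pass dtype = tf.float32 to describe the output. Otherwise tf.map_fn() assumes the output has the same structure as the input.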

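Run this code and you will get the output below. The first two lines come from the print() calls in xa(): in graph mode, xa() is traced only once, so x[0] and x[1] are symbolic tensors of shape (2,), one row of v1 and one row of v2.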

Tensor("map/while/TensorArrayReadV3:0", shape=(2,), dtype=float32)
Tensor("map/while/TensorArrayReadV3_1:0", shape=(2,), dtype=float32)
[[3. 4.]
 [8. 9.]]
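The row pairing described above can be emulated in plain NumPy. This sketch assumes a hypothetical helper, map_fn_tuple_sketch(), which is not a TensorFlow API. It checks that all inputs share the same axis-0 size, as tf.map_fn() expects, and then passes (v1[i], v2[i]) to xa() for each i.

```python
import numpy as np

def map_fn_tuple_sketch(fn, elems):
    """Hypothetical NumPy illustration of tf.map_fn with a tuple of inputs.
    For each index i, `fn` receives the tuple (elems[0][i], elems[1][i], ...)."""
    n = elems[0].shape[0]
    if any(e.shape[0] != n for e in elems):
        raise ValueError("all inputs must have the same size on axis 0")
    results = [fn(tuple(e[i] for e in elems)) for i in range(n)]
    return np.stack(results, axis=0)

v1 = np.array([[1, 2], [4, 5]], dtype=np.float32)
v2 = np.array([[2, 2], [4, 4]], dtype=np.float32)

def xa(x):
    # x[0] is a row of v1, x[1] is the matching row of v2
    return x[0] + x[1]

t = map_fn_tuple_sketch(xa, (v1, v2))
print(t)
# [[3. 4.]
#  [8. 9.]]
```

The result matches the TensorFlow output above, because both simply add v1[i] and v2[i] row by row.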

If you get ValueError: The two structures don’t have the same nested structure, you can read this tutorial to fix it:

Fix tf.map_fn() ValueError: The two structures don’t have the same nested structure – TensorFlow Tutorial

How to return multiple tensors in tf.map_fn()

tf.map_fn() can also return multiple tensors. For example:

# Written for TensorFlow 1.x (graph mode with tf.Session)
import tensorflow as tf
import numpy as np
data = np.array([[1,2], [4,5]], dtype=np.float32)
v1 = tf.convert_to_tensor(data, dtype = tf.float32)

data = np.array([[2,2], [4,4]], dtype=np.float32)
v2 = tf.convert_to_tensor(data, dtype = tf.float32)

def xa(x):
    # x[0]: one row of v1
    # x[1]: the matching row of v2
    print(x[0])
    print(x[1])
    # return three tensors: row sum, sum of the v1 row, element-wise product
    return x[0]+x[1],tf.reduce_sum(x[0]),x[0]*x[1]
# one dtype per returned tensor
t1, t2, t3 = tf.map_fn(xa, (v1, v2), dtype = (tf.float32, tf.float32, tf.float32))

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    _w = sess.run([t1, t2, t3])
    print(_w)

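In this example, xa() returns three tensors for each row, so tf.map_fn() returns three stacked tensors: t1, t2 and t3. Because this output structure (a 3-tuple) differs from the input structure (a 2-tuple), we must set dtype = (tf.float32, tf.float32, tf.float32), one dtype for each output tensor. Since tf.reduce_sum() returns a scalar for each row, t2 has shape (2,).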

Run this code and print(_w) will output:

[array([[3., 4.],
       [8., 9.]], dtype=float32), array([3., 9.], dtype=float32), array([[ 2.,  4.],
       [16., 20.]], dtype=float32)]
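How the per-row tuples turn into three separate outputs can also be sketched in NumPy. The helper map_fn_multi_sketch() below is hypothetical, not a TensorFlow API. It collects the tuple returned by xa() for each row, then stacks the first items into t1, the second into t2 and the third into t3.

```python
import numpy as np

def map_fn_multi_sketch(fn, elems):
    """Hypothetical NumPy illustration of tf.map_fn returning several tensors.
    `fn` returns a tuple per row; the k-th items of all those tuples are
    stacked into the k-th output array."""
    n = elems[0].shape[0]
    per_row = [fn(tuple(e[i] for e in elems)) for i in range(n)]
    # zip(*per_row) regroups [(a0, b0, c0), (a1, b1, c1)]
    # into [(a0, a1), (b0, b1), (c0, c1)]
    return tuple(np.stack(group, axis=0) for group in zip(*per_row))

v1 = np.array([[1, 2], [4, 5]], dtype=np.float32)
v2 = np.array([[2, 2], [4, 4]], dtype=np.float32)

def xa(x):
    # same three outputs as the TensorFlow example, using NumPy ops
    return x[0] + x[1], np.sum(x[0]), x[0] * x[1]

t1, t2, t3 = map_fn_multi_sketch(xa, (v1, v2))
print(t1.shape, t2.shape, t3.shape)  # (2, 2) (2,) (2, 2)
print(t2)  # [3. 9.]
```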
