Fix tf.map_fn() ValueError: The two structures don’t have the same nested structure – TensorFlow Tutorial

By | January 24, 2022

When we use the tf.map_fn() function, we may get this error: ValueError: The two structures don't have the same nested structure. In this tutorial, we will show you how to fix it.

tf.map_fn()

To learn how to use tf.map_fn(), you can read this tutorial:

Understand TensorFlow tf.map_fn(): A Beginner Guide – TensorFlow Tutorial

Look at the example code below:

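import tensorflow as tf  # TensorFlow 1.x graph API (placeholder, Session)
import numpy as np

# a 2x2 input fed at run time and a 2x2 trainable weight matrix
input_ids = tf.placeholder(tf.float32, [2, 2], name="input_ids")
w = tf.Variable(tf.glorot_uniform_initializer()([2, 2]), name="w")

def xa(x):
    # x is the tuple (row of input_ids, row of w); return their sum
    return x[0] + x[1]

# map_fn unstacks input_ids and w along axis 0 and calls xa() on each pair of rows.
# No dtype is given here, so this line raises the ValueError.
t = tf.map_fn(xa, (input_ids, w))

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    # np.float was removed in NumPy 1.24; use float32 to match the placeholder
    data = np.array([[1, 2], [4, 5]], dtype=np.float32)
    feed_dict = {
        input_ids: data
    }
    _w = sess.run([t], feed_dict=feed_dict)
    print(_w)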

Run this code and you will see this error:

ValueError: The two structures don't have the same nested structure.
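Here elems is the tuple (input_ids, w), a structure with two tensors, while xa() returns a single tensor. You can reproduce this kind of mismatch outside tf.map_fn() with tf.nest.assert_same_structure(), which compares two nested structures and raises ValueError when they differ. The sketch below is only an illustration. It assumes TensorFlow 2.x eager mode and uses zero tensors as stand-ins, and it is not necessarily the exact internal check that tf.map_fn() runs.

```python
import tensorflow as tf

# stand-in for elems passed to tf.map_fn(): a tuple of two [2, 2] tensors
elems = (tf.zeros([2, 2]), tf.zeros([2, 2]))
# stand-in for what xa() returns for one pair of rows: a single tensor
output = tf.zeros([2])

def structures_match(a, b):
    """Return True if a and b have the same nesting, False otherwise."""
    try:
        tf.nest.assert_same_structure(a, b)
        return True
    except (ValueError, TypeError) as err:
        # a nesting mismatch raises ValueError (TypeError for mismatched sequence types)
        print(err)
        return False

print(structures_match(elems, output))       # tuple of two vs one tensor: mismatch
print(structures_match(tf.float32, output))  # one dtype for one tensor: match
```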

How to fix this ValueError?

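You should set the dtype parameter of tf.map_fn() to the data type that xa() returns. If dtype is omitted, tf.map_fn() expects fn to return the same structure as elems. Here, elems is a tuple of two tensors, but xa() returns only one tensor.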

For example:

t = tf.map_fn(xa, (input_ids, w), dtype=tf.float32)

Run the example code again and the error is gone.
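On TensorFlow 2.3 and later, dtype still works but is deprecated in favor of the fn_output_signature parameter. Placeholders and sessions also do not exist in eager mode. Below is a minimal TF2 sketch of the same fix. It assumes eager execution and replaces the randomly initialized w with a fixed constant, which is an assumption made only so that the output is predictable.

```python
import tensorflow as tf

input_ids = tf.constant([[1, 2], [4, 5]], dtype=tf.float32)
# fixed weights instead of a glorot-initialized tf.Variable (illustration only)
w = tf.constant([[0.5, 0.5], [1.0, 1.0]], dtype=tf.float32)

def xa(x):
    # x is the tuple (row of input_ids, row of w)
    return x[0] + x[1]

# fn_output_signature declares that xa() returns a single float32 tensor
t = tf.map_fn(xa, (input_ids, w), fn_output_signature=tf.float32)
print(t.numpy())
```

The result has shape [2, 2]. Row i is input_ids[i] + w[i].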


In this example, xa() returns a single tensor. What if it returns several tensors?

For example:

def xa(x):
    return x[0]+x[1], x[0]*x[1], x[0]+2*x[1]

Here, xa() returns a tuple of 3 tensors.

To avoid the error in this case, dtype must be a tuple with one data type per returned value:

t = tf.map_fn(xa, (input_ids, w), dtype=(tf.float32, tf.float32, tf.float32))

The 3 data types in dtype must match, in order, the types of the 3 tensors returned by xa().
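Here is the same multi-output case in TensorFlow 2.x, as a sketch under the same assumptions (eager mode, constant w in place of the Variable). fn_output_signature takes a tuple with one entry per returned tensor, and tf.map_fn() returns the stacked results in the same order.

```python
import tensorflow as tf

input_ids = tf.constant([[1, 2], [4, 5]], dtype=tf.float32)
# fixed weights instead of a glorot-initialized tf.Variable (illustration only)
w = tf.constant([[0.5, 0.5], [1.0, 1.0]], dtype=tf.float32)

def xa(x):
    # three outputs per pair of rows -> three entries in fn_output_signature
    return x[0] + x[1], x[0] * x[1], x[0] + 2 * x[1]

s, p, q = tf.map_fn(
    xa, (input_ids, w),
    fn_output_signature=(tf.float32, tf.float32, tf.float32))
print(s.numpy(), p.numpy(), q.numpy(), sep="\n")
```

The entries do not have to share one type. If xa() returned an int32 tensor as its second value, the second entry would be tf.int32.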
