Understand TensorArray.gather() Method with Examples – TensorFlow Tutorial

By | December 16, 2020

The TensorFlow TensorArray.gather() method reads several elements of a TensorArray by their indices and returns them stacked into a single tensor. In this tutorial, we will use an example to show how to use it.

Syntax

The gather() method is defined as:

gather(
    indices, name=None
)

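Here indices holds the positions to read. It can be a Python list of integers, such as [1, 2, 3], or a 1-D int32 tensor. name is an optional name for the operation.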
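A minimal sketch of the indices argument, assuming TensorFlow 2.x with eager execution (the article's own example further down uses the 1.x Session API). It passes the same positions once as a Python list and once as an int32 tensor, and shows that the result follows the order of the indices, not the order of the writes. The name ta and the stored values are hypothetical.

```python
import tensorflow as tf

# Hypothetical TensorArray of 4 scalars: 0.0, 10.0, 20.0, 30.0
# clear_after_read=False keeps entries after they are read, so we can gather twice
ta = tf.TensorArray(dtype=tf.float32, size=4, clear_after_read=False)
for i in range(4):
    ta = ta.write(i, float(i * 10))  # write() returns a new TensorArray

# indices may be a Python list or a 1-D int32 tensor
from_list = ta.gather([3, 0])
from_tensor = ta.gather(tf.constant([3, 0], dtype=tf.int32))

# Both results hold element 3 first, then element 0
print(from_list.numpy())
print(from_tensor.numpy())
```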

TensorArray.gather() is very similar to tf.nn.embedding_lookup(), so you can also refer to this tutorial:
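To make the comparison concrete, here is a sketch, assuming TensorFlow 2.x in eager mode, that stores three rows in a TensorArray and also stacks them into an ordinary tensor. For the same ids, TensorArray.gather(), tf.gather() and tf.nn.embedding_lookup() pick out the same rows. The names rows, ids and ta and their values are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical data: three rows of length 2, and the ids we want to pick
rows = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]), tf.constant([5.0, 6.0])]
ids = [2, 0]

# Store the rows in a TensorArray and gather them by id
ta = tf.TensorArray(dtype=tf.float32, size=len(rows))
for i, r in enumerate(rows):
    ta = ta.write(i, r)
from_ta = ta.gather(ids)

# Look up the same rows from an ordinary [3, 2] tensor
params = tf.stack(rows)
from_lookup = tf.nn.embedding_lookup(params, ids)
from_gather = tf.gather(params, ids)

print(np.array_equal(from_ta.numpy(), from_lookup.numpy()))
print(np.array_equal(from_ta.numpy(), from_gather.numpy()))
```

The practical difference is where the data lives: embedding_lookup() and tf.gather() index a regular tensor, while gather() is a method of the TensorArray itself.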

Understand tf.nn.embedding_lookup(): Pick Up Elements by Ids

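Note: TensorArray.gather() returns a Tensor, not a TensorArray. The selected elements are stacked along a new first dimension, so gathering k elements of shape S gives a tensor of shape [k] + S.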
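A quick check of the return type and shape, assuming TensorFlow 2.x in eager mode. The TensorArray ta and its three 4x3 elements (filled with 0.0, 1.0 and 2.0) are hypothetical.

```python
import tensorflow as tf

# Hypothetical TensorArray of three 4x3 elements filled with 0.0, 1.0, 2.0
ta = tf.TensorArray(dtype=tf.float32, size=3)
for i in range(3):
    ta = ta.write(i, tf.fill([4, 3], float(i)))

x = ta.gather([0, 2])

# x is a regular tf.Tensor, not a tf.TensorArray
print(isinstance(x, tf.Tensor))       # True
print(isinstance(x, tf.TensorArray))  # False
# Two gathered 4x3 elements are stacked along a new first axis
print(x.shape.as_list())              # [2, 4, 3]
```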

Here is an example. The code uses the TensorFlow 1.x graph-mode API (tf.Session).

We first create a TensorArray and write some tensors into it.

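import numpy as np
import tensorflow as tf  # TensorFlow 1.x

# A 4x3 float32 tensor; we will store five copies of it
cap_map = tf.convert_to_tensor(
    np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=float),
    dtype=tf.float32)

# A fixed-size TensorArray with room for five elements
gen_o = tf.TensorArray(dtype=tf.float32, size=5,
                       dynamic_size=False, infer_shape=True)

# write() returns a new TensorArray, so reassign the result each time
gen_o = gen_o.write(0, cap_map)
gen_o = gen_o.write(1, cap_map)
gen_o = gen_o.write(2, cap_map)
gen_o = gen_o.write(3, cap_map)
gen_o = gen_o.write(4, cap_map)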

In this code, we create a TensorArray gen_o and write five tensors into it. Every element is the same 4x3 tensor cap_map.
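The five write() calls can also be written as a loop, or replaced by a single TensorArray.unstack() call that splits a tensor along its first axis. A sketch assuming TensorFlow 2.x in eager mode, reusing the article's cap_map values; the names ta_loop and ta_unstack are illustrative.

```python
import numpy as np
import tensorflow as tf

cap_map = tf.constant([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=tf.float32)

# Option 1: write the same 4x3 tensor into slots 0..4 in a loop
ta_loop = tf.TensorArray(dtype=tf.float32, size=5, dynamic_size=False, infer_shape=True)
for i in range(5):
    ta_loop = ta_loop.write(i, cap_map)  # write() returns a new TensorArray

# Option 2: unstack a [5, 4, 3] tensor; slice i becomes element i
ta_unstack = tf.TensorArray(dtype=tf.float32, size=5)
ta_unstack = ta_unstack.unstack(tf.stack([cap_map] * 5))

# Both arrays have five elements with identical contents
print(ta_loop.size().numpy(), ta_unstack.size().numpy())
print(np.array_equal(ta_loop.stack().numpy(), ta_unstack.stack().numpy()))
```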

Then we gather the elements at indices 1 and 2.

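# Read elements 1 and 2; they are stacked into one tensor of shape [2, 4, 3]
x = gen_o.gather(indices=[1, 2])
print(type(x))

# There are no variables here, so the initializer is not strictly needed
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(x))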

Running this code prints:

<class 'tensorflow.python.framework.ops.Tensor'>
[[[-1.  3.  2.]
  [-3.  1.  3.]
  [ 2. -7.  4.]
  [ 5.  7.  6.]]

 [[-1.  3.  2.]
  [-3.  1.  3.]
  [ 2. -7.  4.]
  [ 5.  7.  6.]]]

Because x is an ordinary tensor of shape [2, 4, 3], we can pass it directly to operations such as tf.matmul(), tf.add() and tf.tanh().
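For example, assuming TensorFlow 2.x in eager mode and the same gen_o as in the article, the gathered x can feed these operations directly. The tf.matmul() call multiplies each 4x3 slice by its own transpose, a batched matrix product over the first axis; the names y_tanh, y_add and y_mat are illustrative.

```python
import numpy as np
import tensorflow as tf

cap_map = tf.constant([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=tf.float32)
gen_o = tf.TensorArray(dtype=tf.float32, size=5)
for i in range(5):
    gen_o = gen_o.write(i, cap_map)

x = gen_o.gather([1, 2])  # shape [2, 4, 3]

y_tanh = tf.tanh(x)                        # element-wise, shape [2, 4, 3]
y_add = tf.add(x, 1.0)                     # adds 1.0 to every entry, shape [2, 4, 3]
y_mat = tf.matmul(x, x, transpose_b=True)  # batched matmul, shape [2, 4, 4]

print(y_tanh.shape.as_list(), y_add.shape.as_list(), y_mat.shape.as_list())
# Each batch slice equals cap_map @ cap_map^T
print(np.allclose(y_mat[0].numpy(), cap_map.numpy() @ cap_map.numpy().T))
```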
