In this tutorial, we will show how to convert a tensor to a TensorArray object in TensorFlow, which is very useful when you are building a custom LSTM or BiLSTM.
Create a tensor
We create a tensor first.
import tensorflow as tf
import numpy as np

# 36 float values 0, 1, ..., 35
x = tf.Variable(np.array(range(36)), dtype=np.float32, name='x')
# reshape to a 3-D tensor of shape (3, 3, 4)
x = tf.reshape(x, [3, 3, 4])
The tensor x will be:
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]]

 [[12. 13. 14. 15.]
  [16. 17. 18. 19.]
  [20. 21. 22. 23.]]

 [[24. 25. 26. 27.]
  [28. 29. 30. 31.]
  [32. 33. 34. 35.]]]
Create a TensorArray to store tensors
We create a TensorArray object gen_o with size=3, one slot for each slice of x along its first dimension.
gen_o = tf.TensorArray(dtype=tf.float32, size=3, dynamic_size=False, infer_shape=True)
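If you are on TensorFlow 2.x, here is a minimal sketch (assuming eager execution, the 2.x default) of what such a fixed-size TensorArray holds: we fill its three slots by hand with write(), then pack them into one tensor with stack(). The fill values are arbitrary placeholders chosen for illustration.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x eager mode (the tutorial itself uses TF 1.x).
# A fixed-size TensorArray with 3 slots, created as in the tutorial.
ta = tf.TensorArray(dtype=tf.float32, size=3, dynamic_size=False, infer_shape=True)

# write() returns a new TensorArray, so always keep the returned object.
for i in range(3):
    ta = ta.write(i, tf.fill([3, 4], float(i)))  # slot i holds a (3, 4) tensor of i's

print(int(ta.size()))   # 3
stacked = ta.stack()    # packs the slots into one tensor along a new axis 0
print(stacked.shape)    # (3, 3, 4)
```

The conversion methods in this tutorial do the same job in a single call, taking the slices from an existing tensor instead of writing them one by one.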
To learn more about creating and using a TensorArray, you can read:
Understand TensorFlow TensorArray: A Beginner Tutorial
Convert a tensor to a TensorArray
There are two ways to convert a tensor to a TensorArray.
Method 1: use TensorArray.unstack()
Here is an example:
gen_o = gen_o.unstack(x)
z0 = gen_o.read(0)
z1 = gen_o.read(1)
z2 = gen_o.read(2)
Then print z0, z1 and z2.
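# tf.Session and tf.global_variables_initializer are TensorFlow 1.x APIs
# (available as tf.compat.v1.* in TensorFlow 2.x).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize the variable x
    print(sess.run(x))
    print("z0=")
    print(sess.run(z0))
    print("z1=")
    print(sess.run(z1))
    print("z2=")
    print(sess.run(z2))
After the value of x, you will get: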
z0=
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]]
z1=
[[12. 13. 14. 15.]
 [16. 17. 18. 19.]
 [20. 21. 22. 23.]]
z2=
[[24. 25. 26. 27.]
 [28. 29. 30. 31.]
 [32. 33. 34. 35.]]
Each of z0, z1 and z2 has shape (3, 4): unstack() removes the first dimension of x, so slot i holds x[i].
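The code above relies on tf.Session, which only exists in TensorFlow 1.x. Assuming TensorFlow 2.x with eager execution, a minimal sketch of the same unstack() conversion could look like this. As an assumption for brevity, x is built directly with tf.range instead of a tf.Variable, and the sketch also checks that stack() turns the TensorArray back into x.

```python
import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x eager mode, so no tf.Session is needed.
x = tf.reshape(tf.range(36, dtype=tf.float32), [3, 3, 4])

# clear_after_read=False keeps each slot readable after read(), so stack()
# still sees the data afterwards (the default, True, clears a slot once read).
gen_o = tf.TensorArray(dtype=tf.float32, size=3, dynamic_size=False,
                       infer_shape=True, clear_after_read=False)
gen_o = gen_o.unstack(x)  # slot i now holds x[i]

z0, z1, z2 = gen_o.read(0), gen_o.read(1), gen_o.read(2)
print(z0.shape)  # (3, 4)

# stack() is the inverse of unstack(): it rebuilds the (3, 3, 4) tensor.
restored = gen_o.stack()
print(np.array_equal(restored.numpy(), x.numpy()))  # True
```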
Method 2: use TensorArray.split()
Here is an example:
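# A fresh TensorArray: the one already filled by unstack() cannot be written again.
gen_o = tf.TensorArray(dtype=tf.float32, size=3, dynamic_size=False, infer_shape=True)
gen_o = gen_o.split(x, lengths=[1, 1, 1])
z0 = gen_o.read(0)
z1 = gen_o.read(1)
z2 = gen_o.read(2)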
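Note lengths=[1, 1, 1]: split() cuts x along its first dimension into len(lengths) = 3 pieces, each of length 1 along that dimension, so the TensorArray holds 3 tensors. The values in lengths must sum to the size of x's first dimension, which is 3 here.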
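The pieces do not have to be the same length. Here is a small sketch, assuming TensorFlow 2.x eager mode, that splits x into one piece of length 1 and one of length 2, then joins them back with concat(). The choice of lengths=[1, 2] is just an illustration.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x eager mode. The entries of lengths must sum to
# x.shape[0], and a fixed-size TensorArray needs exactly one slot per entry.
x = tf.reshape(tf.range(36, dtype=tf.float32), [3, 3, 4])

# infer_shape=False because the two pieces have different shapes;
# clear_after_read=False so the slots can still be concatenated after reading.
ta = tf.TensorArray(dtype=tf.float32, size=2, infer_shape=False,
                    clear_after_read=False)
ta = ta.split(x, lengths=[1, 2])

first = ta.read(0)   # x[0:1]
second = ta.read(1)  # x[1:3]
print(first.shape)   # (1, 3, 4)
print(second.shape)  # (2, 3, 4)

# concat() is the inverse of split(): it joins the pieces along axis 0.
print(ta.concat().shape)  # (3, 3, 4)
```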
Printing z0, z1 and z2 gives:
z0=
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]]]
z1=
[[[12. 13. 14. 15.]
  [16. 17. 18. 19.]
  [20. 21. 22. 23.]]]
z2=
[[[24. 25. 26. 27.]
  [28. 29. 30. 31.]
  [32. 33. 34. 35.]]]
This time each of z0, z1 and z2 has shape (1, 3, 4): split() keeps the first dimension (of size 1 here), whereas unstack() removes it.
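To connect this back to the custom LSTM use case from the introduction, here is a sketch of the usual time-loop pattern, assuming TensorFlow 2.x and a time-major input of shape [time, batch, features]: unstack() the input, read one step per iteration of tf.while_loop, write each output to a second TensorArray, and stack() at the end. The function name run_over_time and the running-sum "cell" are hypothetical stand-ins for a real LSTM cell.

```python
import tensorflow as tf

# Hypothetical toy recurrence h_t = h_{t-1} + x_t, standing in for an LSTM cell.
# Assumes TensorFlow 2.x; x is time-major: [time, batch, features].
@tf.function
def run_over_time(x):
    num_steps = tf.shape(x)[0]
    # unstack() so that inputs.read(t) returns x[t] with shape [batch, features]
    inputs = tf.TensorArray(dtype=tf.float32, size=num_steps).unstack(x)
    outputs = tf.TensorArray(dtype=tf.float32, size=num_steps)

    def step(t, h, outputs):
        h = h + inputs.read(t)              # one recurrent update
        return t + 1, h, outputs.write(t, h)

    _, _, outputs = tf.while_loop(
        lambda t, h, outputs: t < num_steps,
        step,
        (tf.constant(0), tf.zeros_like(x[0]), outputs))
    return outputs.stack()                  # back to [time, batch, features]

x = tf.reshape(tf.range(36, dtype=tf.float32), [3, 3, 4])
y = run_over_time(x)
print(y.shape)  # (3, 3, 4)
```

unstack() fits this pattern because each read(t) returns a plain [batch, features] step; with split() and lengths of 1, every step would carry an extra leading dimension of size 1.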