tf.reverse_sequence() is an important function for building a custom BiLSTM model, and it behaves differently from tf.reverse().
To understand tf.reverse(), you can read:
Understand TensorFlow tf.reverse(): Reverse a Tensor Based on Axis
The difference between tf.reverse_sequence() and tf.reverse()
Both of them reverse a tensor along an axis. However, tf.reverse() always reverses the whole axis, while tf.reverse_sequence() reverses each sequence only up to its own length.
For example:
For example, the sentence "I love this food and service" contains 6 words. If you process it with a BiLSTM model and set time_step to 8, the sentence is padded with two invalid (padding) words.
The sentence will be viewed as: I love this food and service – –
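If we reverse this sentence with tf.reverse_sequence() (sequence length 6) and with tf.reverse(), the results will be:
tf.reverse_sequence(): service and food this love I – –
tf.reverse(): – – service and food this love I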
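The two behaviours can be sketched on a plain Python list of tokens. This is only an illustration of the semantics, not TensorFlow code; reverse_all, reverse_valid and PAD are hypothetical names chosen for this example.

```python
PAD = "–"  # placeholder for a padded (invalid) time step

def reverse_all(tokens):
    # Mirrors tf.reverse along the time axis: every position is reversed,
    # padding included.
    return tokens[::-1]

def reverse_valid(tokens, length):
    # Mirrors tf.reverse_sequence with a sequence length of `length`:
    # only the first `length` tokens are reversed, padding stays at the end.
    return tokens[:length][::-1] + tokens[length:]

sentence = ["I", "love", "this", "food", "and", "service", PAD, PAD]  # time_step = 8

print(" ".join(reverse_valid(sentence, 6)))  # service and food this love I – –
print(" ".join(reverse_all(sentence)))       # – – service and food this love I
```

With tf.reverse(), the padding ends up in front of the real words, which is usually not what the backward LSTM should see.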
How do we use TensorFlow tf.reverse_sequence()? We will discuss this topic in this tutorial.
Syntax
tf.reverse_sequence( input, seq_lengths, seq_axis=None, batch_axis=None, name=None, seq_dim=None, batch_dim=None )
For each slice along batch_axis, tf.reverse_sequence() reverses the first seq_lengths[i] elements along seq_axis and leaves the remaining elements unchanged.
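To make these semantics concrete, here is a hypothetical NumPy helper, np_reverse_sequence, that mimics tf.reverse_sequence(input, seq_lengths, seq_axis, batch_axis) on NumPy arrays. It is a sketch for understanding the behaviour, not TensorFlow's implementation.

```python
import numpy as np

def np_reverse_sequence(x, seq_lengths, seq_axis, batch_axis=0):
    """Hypothetical NumPy reference for tf.reverse_sequence semantics.

    For each index b along batch_axis, the first seq_lengths[b] elements
    along seq_axis are reversed; the rest are left unchanged.
    """
    x = np.asarray(x)
    # Move batch_axis to position 0 and seq_axis to position 1 (copy, so x is untouched).
    moved = np.moveaxis(x, (batch_axis, seq_axis), (0, 1)).copy()
    for b, length in enumerate(seq_lengths):
        moved[b, :length] = moved[b, :length][::-1].copy()
    # Restore the original axis order.
    return np.moveaxis(moved, (0, 1), (batch_axis, seq_axis))

print(np_reverse_sequence(np.arange(6).reshape(2, 3), [2, 3], seq_axis=1).tolist())
# [[1, 0, 2], [5, 4, 3]]
```

Moving the two axes to the front keeps the loop simple for any input rank; the last call shows that only the first 2 elements of row 0 are reversed, while row 1 is reversed completely.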
Parameters explained
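input: the tensor to be reversed.
seq_lengths: a list or 1-D tensor of integers, with one length for each element along batch_axis. For example, seq_lengths = [1, 2, 3]. Each value must not be larger than the size of seq_axis.
seq_axis: int, a scalar (not a list, unlike the axis argument of tf.reverse()). It is the axis that is partially reversed.
batch_axis: int, a scalar. It is the axis along which the input is split into samples, and it defaults to 0.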
We will use an example to show you how to use this function.
Suppose you have a tensor with shape (batch_size, time_step, dim), and you want to reverse it along time_step.
For example, create a tensor with shape (3, 3, 4),
where batch_size = 3, time_step = 3 and dim = 4.
With batch_axis = 0 and seq_axis = 1, the length of seq_lengths must equal batch_size, and every value in seq_lengths must be less than or equal to time_step.
For example:
seq_lengths = [1, 2, 3] is valid.
seq_lengths = [1, 2, 4] is invalid, because 4 > time_step.
seq_lengths = [1, 2] is invalid, because it gives only 2 lengths for 3 samples.
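This rule can be written as a small check. check_seq_lengths is a hypothetical helper that assumes batch_axis = 0 and seq_axis = 1, as in this example. TensorFlow performs its own validation at run time, so the helper is only a way to make the rule explicit.

```python
def check_seq_lengths(seq_lengths, batch_size, time_step):
    """Hypothetical helper, assuming batch_axis = 0 and seq_axis = 1.

    Returns True when seq_lengths has one entry per sample and every
    entry lies between 0 and time_step (inclusive).
    """
    if len(seq_lengths) != batch_size:
        return False  # one length per sample is required
    return all(0 <= n <= time_step for n in seq_lengths)

print(check_seq_lengths([1, 2, 3], batch_size=3, time_step=3))  # True
print(check_seq_lengths([1, 2, 4], batch_size=3, time_step=3))  # False: 4 > time_step
print(check_seq_lengths([1, 2], batch_size=3, time_step=3))     # False: 2 lengths for 3 samples
```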
Then we will create a 3 * 3 * 4 tensor.
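# TensorFlow 1.x style code (tf.Session is used below)
import tensorflow as tf
import numpy as np

# Values 0..35 as float32, reshaped to (batch_size, time_step, dim) = (3, 3, 4)
x = tf.Variable(np.array(range(36)), dtype = np.float32, name = 'x')
x = tf.reshape(x, [3, 3, 4])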
x will be:
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]]

 [[12. 13. 14. 15.]
  [16. 17. 18. 19.]
  [20. 21. 22. 23.]]

 [[24. 25. 26. 27.]
  [28. 29. 30. 31.]
  [32. 33. 34. 35.]]]
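If you want to inspect the same values without a TensorFlow 1.x session, a NumPy array with the same layout can be built directly. This is an illustration only, not part of the original code: axis 0 indexes the sample, axis 1 the time step and axis 2 the feature (dim).

```python
import numpy as np

# Same values and layout as x: (batch_size, time_step, dim) = (3, 3, 4)
x_np = np.arange(36, dtype=np.float32).reshape(3, 3, 4)

print(x_np.shape)           # (3, 3, 4)
print(x_np[1, 0].tolist())  # sample 1, time step 0: [12.0, 13.0, 14.0, 15.0]
```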
Set the seq_lengths
As discussed above, we can set it to be [1, 2, 3]
seq_lengths = [1, 2, 3]
Reverse the tensor by seq_lengths along time_step
x1 = tf.reverse_sequence(x, seq_lengths, seq_axis = 1)
The axis of time_step is 1, so we set seq_axis = 1. batch_axis is not set, so it keeps its default value 0, which is the batch_size axis.
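seq_axis and batch_axis only describe where the time and batch dimensions are stored. If the same data were time-major, with shape (time_step, batch_size, dim), the call would need seq_axis = 0 and batch_axis = 1 instead. The NumPy sketch below uses plain loops in place of tf.reverse_sequence (no TensorFlow calls) to check that both layouts give the same result.

```python
import numpy as np

x = np.arange(36, dtype=np.float32).reshape(3, 3, 4)  # (batch_size, time_step, dim)
seq_lengths = [1, 2, 3]

# Batch-major: batch axis 0, sequence axis 1 (the layout used in this tutorial)
x1 = x.copy()
for b, n in enumerate(seq_lengths):
    x1[b, :n] = x[b, :n][::-1]

# Time-major copy of the same data: (time_step, batch_size, dim).
# Here the sequence axis is 0 and the batch axis is 1.
x_tm = np.transpose(x, (1, 0, 2))
x1_tm = x_tm.copy()
for b, n in enumerate(seq_lengths):
    x1_tm[:n, b] = x_tm[:n, b][::-1]

# Transposing the time-major result back gives the batch-major result.
print(np.array_equal(np.transpose(x1_tm, (1, 0, 2)), x1))  # True
```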
Print x and x1.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(x))
    print(sess.run(x1))
Then we will find that x1 is:
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]]

 [[16. 17. 18. 19.]
  [12. 13. 14. 15.]
  [20. 21. 22. 23.]]

 [[32. 33. 34. 35.]
  [28. 29. 30. 31.]
  [24. 25. 26. 27.]]]
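This output can be reproduced without TensorFlow. The sketch below assumes NumPy and replaces the tf.reverse_sequence call with an explicit loop over samples; it prints the first feature of every time step so the reordering is easy to read.

```python
import numpy as np

x = np.arange(36, dtype=np.float32).reshape(3, 3, 4)
seq_lengths = [1, 2, 3]

# Reverse the first seq_lengths[b] time steps (axis 1) of sample b.
x1 = x.copy()
for b, n in enumerate(seq_lengths):
    x1[b, :n] = x[b, :n][::-1]

# First feature of every time step, one row per sample
print(x1[:, :, 0].tolist())
# [[0.0, 4.0, 8.0], [16.0, 12.0, 20.0], [32.0, 28.0, 24.0]]
```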
Compare x with x1
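Comparing x and x1 sample by sample: the first sample (length 1) is unchanged, the second sample (length 2) has its first two time steps swapped while its third time step stays in place, and the third sample (length 3 = time_step) is fully reversed.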
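To connect this back to tf.reverse(): reversing axis 1 of every sample (what tf.reverse(x, axis=[1]) does) matches tf.reverse_sequence() only for samples whose length equals time_step. The NumPy sketch below checks this per sample; slicing and a loop stand in for the two TensorFlow calls.

```python
import numpy as np

x = np.arange(36, dtype=np.float32).reshape(3, 3, 4)
seq_lengths = [1, 2, 3]

# Like tf.reverse(x, axis=[1]): every time step of every sample is reversed.
full = x[:, ::-1, :]

# Like tf.reverse_sequence(x, seq_lengths, seq_axis=1): only the first
# seq_lengths[b] time steps of sample b are reversed.
partial = x.copy()
for b, n in enumerate(seq_lengths):
    partial[b, :n] = x[b, :n][::-1]

for b, n in enumerate(seq_lengths):
    same = np.array_equal(full[b], partial[b])
    print(f"sample {b} (length {n}): identical to full reverse? {same}")
# sample 0 (length 1): identical to full reverse? False
# sample 1 (length 2): identical to full reverse? False
# sample 2 (length 3): identical to full reverse? True
```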