We often use a Bi-LSTM or Bi-GRU to model sequences. How do we get its output? In this tutorial, we will discuss this topic.
A Bi-LSTM or Bi-GRU looks like this:
A Bi-LSTM contains a forward LSTM and a backward LSTM.
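The NumPy sketch below shows how the two directions are combined. It is not TensorFlow's implementation. It assumes a plain LSTM cell with no peepholes, random weights, and a `forget_bias` added to the forget gate. The helpers `init_lstm`, `run_lstm` and `run_bilstm` are hypothetical. The forward LSTM reads steps 0..T-1. The backward LSTM reads the reversed sequence, and its outputs are flipped back so that position t in both outputs refers to input step t.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_lstm(input_dim, hidden_size, rng):
    # Hypothetical helper: one weight matrix for the four gates (i, j, f, o).
    W = rng.normal(0.0, 0.1, size=(input_dim + hidden_size, 4 * hidden_size))
    b = np.zeros(4 * hidden_size)
    return W, b

def run_lstm(x, params, forget_bias=1.0):
    """Run a plain LSTM over x [batch, time, dim]; return outputs [batch, time, hidden]."""
    W, b = params
    batch, time, _ = x.shape
    hidden = W.shape[1] // 4
    h = np.zeros((batch, hidden))
    c = np.zeros((batch, hidden))
    outputs = []
    for t in range(time):
        z = np.concatenate([x[:, t, :], h], axis=1) @ W + b
        i, j, f, o = np.split(z, 4, axis=1)   # input, candidate, forget, output gates
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1)

def run_bilstm(x, params_fw, params_bw):
    out_fw = run_lstm(x, params_fw)
    # Backward LSTM: read the sequence reversed, then flip its outputs back
    # so out_bw[:, t] lines up with input step t.
    out_bw = run_lstm(x[:, ::-1, :], params_bw)[:, ::-1, :]
    return out_fw, out_bw

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 100))           # [batch_size, max_seq_length, dim]
params_fw = init_lstm(100, 10, rng)
params_bw = init_lstm(100, 10, rng)
out_fw, out_bw = run_bilstm(x, params_fw, params_bw)
print(out_fw.shape, out_bw.shape)          # (3, 5, 10) (3, 5, 10)
```

Because the backward LSTM sees the last input step first, `out_bw[:, -1]` depends only on `x[:, -1]`. Its final state is reached at `out_bw[:, 0]`.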
How to create a Bi-LSTM in TensorFlow?
In TensorFlow, we can use tf.nn.bidirectional_dynamic_rnn() to create a Bi-LSTM network.
Here is an example:
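```python
import tensorflow as tf  # TensorFlow 1.x API
import numpy as np

# inputs: [batch_size=3, max_seq_length=5, dim=100]
inputs = tf.Variable(tf.truncated_normal([3, 5, 100], stddev=0.1), name="inputs")
hidden_size = 10

# outputs = (output_fw, output_bw), each [3, 5, hidden_size]
# state = (state_fw, state_bw), the final states of the two LSTMs
outputs, state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    cell_bw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    inputs=inputs,
    sequence_length=None,
    dtype=tf.float32,
    scope='bilstm_doc_word'
)
```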
Notice that the input tensor of the Bi-LSTM has shape [batch_size, max_seq_length, dim]. In this example, it is [3, 5, 100].
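Real sequences usually have different lengths. To build a [batch_size, max_seq_length, dim] input, they must be zero-padded to the longest one. The true lengths can then be passed as the int vector `sequence_length` of shape [batch_size] instead of `None`. Below is a minimal NumPy sketch. `pad_batch` is a hypothetical helper, not a TensorFlow or Keras function, and the lengths 5, 3, 2 are made up for illustration.

```python
import numpy as np

def pad_batch(sequences):
    """Hypothetical helper: stack [len_i, dim] arrays into a zero-padded [batch, max_len, dim]."""
    lengths = np.array([len(s) for s in sequences])
    max_len = int(lengths.max())
    dim = sequences[0].shape[1]
    batch = np.zeros((len(sequences), max_len, dim), dtype=np.float32)
    for k, seq in enumerate(sequences):
        batch[k, :len(seq), :] = seq      # real steps first, zeros after
    return batch, lengths

rng = np.random.default_rng(0)
seqs = [rng.normal(size=(n, 100)) for n in (5, 3, 2)]
inputs, seq_lengths = pad_batch(seqs)
print(inputs.shape, seq_lengths)          # (3, 5, 100) [5 3 2]
```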
How to get the output of a Bi-LSTM?
There are three common ways to get the output of a Bi-LSTM.
Type 1: concatenate the outputs of the forward and backward LSTMs.
```python
outputs_merge = tf.concat(outputs, 2)  # join output_fw and output_bw on the last axis
```
You will then get an output with shape [batch_size, max_seq_length, 2*hidden_size], where hidden_size is the hidden state size of each LSTM.
In the example above, outputs_merge has shape [3, 5, 2*10] = [3, 5, 20].
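To check the shape arithmetic without a TensorFlow session, here is a NumPy sketch. The random arrays stand in for `output_fw` and `output_bw`. `np.concatenate(..., axis=2)` does the same join as `tf.concat(outputs, 2)`.

```python
import numpy as np

batch_size, max_seq_length, hidden_size = 3, 5, 10
rng = np.random.default_rng(0)
# Stand-ins for the (output_fw, output_bw) tuple, each [batch, time, hidden].
output_fw = rng.normal(size=(batch_size, max_seq_length, hidden_size))
output_bw = rng.normal(size=(batch_size, max_seq_length, hidden_size))

# Join along the last (feature) axis, like tf.concat(outputs, 2).
outputs_merge = np.concatenate([output_fw, output_bw], axis=2)
print(outputs_merge.shape)  # (3, 5, 20)
```

The first hidden_size features at each step come from the forward LSTM, and the last hidden_size come from the backward LSTM.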
Type 2: average the outputs of the forward and backward LSTMs over the time steps.
You can find how to do this in the following tutorial:
Average the Output of RNN/GRU/LSTM/BiLSTM for Variable Length Sequences – Deep Learning Tutorial
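The core idea fits in a few lines of NumPy. For padded batches, average only over each sequence's real steps, because dividing by max_seq_length would count the padding. The sketch below is our own masked mean and may differ from the linked tutorial's TensorFlow code. `masked_mean` and the lengths are hypothetical.

```python
import numpy as np

def masked_mean(outputs, lengths):
    """Hypothetical helper: average [batch, time, hidden] over the first lengths[k] steps of row k."""
    max_len = outputs.shape[1]
    # mask[k, t] = 1.0 for real steps (t < lengths[k]), 0.0 for padding
    mask = (np.arange(max_len)[None, :] < lengths[:, None]).astype(outputs.dtype)
    summed = (outputs * mask[:, :, None]).sum(axis=1)        # [batch, hidden]
    return summed / lengths[:, None].astype(outputs.dtype)   # lengths must be >= 1

rng = np.random.default_rng(0)
lengths = np.array([5, 3, 2])
outputs_merge = rng.normal(size=(3, 5, 20))   # stand-in for the concatenated Bi-LSTM output
mean_out = masked_mean(outputs_merge, lengths)
print(mean_out.shape)  # (3, 20)
```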
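Type 3: get the last hidden output of the forward and backward LSTMs. tf.nn.bidirectional_dynamic_rnn() reverses the backward outputs back into input order. As a result, the forward LSTM's last output is at the final time step (index -1), while the backward LSTM's last output is at time step 0.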
Here is an example:
```python
last_forward = outputs[0][:, -1, :]   # forward output at the last time step, shape [3, 10]
last_backward = outputs[1][:, 0, :]   # backward output at time step 0, shape [3, 10]

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    # Evaluate all three tensors in one run
    f, b, m = sess.run([last_forward, last_backward, outputs_merge])
    print("last forward shape=", f.shape)
    print(f)
    print("last backward shape=", b.shape)
    print(b)
    print("concatenate forward and backward")
    print(m)
```
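The example above uses sequence_length=None, so every sequence has 5 steps and `[:, -1, :]` is correct. If you pass `sequence_length` for padded batches, TensorFlow zero-fills the outputs past each sequence's length. In that case `[:, -1, :]` returns zeros for shorter sequences, so the forward output must be picked at index length-1 row by row. The backward LSTM's last output stays at step 0. Below is a NumPy sketch with random stand-in outputs and a hypothetical `last_outputs` helper. In TF 1.x, the returned `state` (state[0].h and state[1].h for an LSTMCell) should hold the same final vectors.

```python
import numpy as np

def last_outputs(output_fw, output_bw, lengths):
    """Hypothetical helper: final forward/backward outputs per sequence, concatenated.

    output_fw, output_bw: [batch, time, hidden], backward outputs in input order.
    lengths: [batch] true sequence lengths (>= 1).
    """
    rows = np.arange(output_fw.shape[0])
    last_fw = output_fw[rows, lengths - 1, :]   # forward LSTM ends at the last real step
    last_bw = output_bw[:, 0, :]                # backward LSTM ends at step 0
    return np.concatenate([last_fw, last_bw], axis=1)  # [batch, 2 * hidden]

rng = np.random.default_rng(0)
lengths = np.array([5, 3, 2])
output_fw = rng.normal(size=(3, 5, 10))
output_bw = rng.normal(size=(3, 5, 10))
last = last_outputs(output_fw, output_bw, lengths)
print(last.shape)  # (3, 20)
```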
Run this code and you will see: