Average the Output of RNN/GRU/LSTM/BiLSTM for Variable Length Sequences – Deep Learning Tutorial

April 8, 2021

We often use an RNN/GRU/LSTM/BiLSTM to encode a sequence. To turn the per-step outputs of these models into a single vector, we can average them or combine them with attention. In this tutorial, we will introduce how to average their outputs.

Preliminary

Look at the RNN model below:

(Figure: LSTM output mask)

The outputs are \([h_0, h_1, \dots, h_t]\).

How to get these outputs?

For a BiLSTM, we can use the code below:

# inputs: [batch_size, time_step, dim]; word_in_sen_len: [batch_size], the true length of each sequence
outputs, state = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=tf.nn.rnn_cell.LSTMCell(self.hidden_size, forget_bias=1.0),
                    cell_bw=tf.nn.rnn_cell.LSTMCell(self.hidden_size, forget_bias=1.0),
                    inputs=inputs,
                    sequence_length=word_in_sen_len,
                    dtype=tf.float32,
                    scope='bilstm_doc_word'
                )
# concatenate the forward and backward outputs: [batch_size, time_step, 2 * hidden_size]
outputs = tf.concat(outputs, 2)

In this code, inputs is a tensor with the shape [batch_size, time_step, dim], and word_in_sen_len holds the true length of each sequence.
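As a quick, self-contained sketch of how this snippet can be run under TensorFlow 1.x (the sizes batch_size = 2, time_step = 4, dim = 8 and hidden_size = 16 are made up for illustration, and a plain hidden_size variable replaces self.hidden_size):

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

batch_size, time_step, dim, hidden_size = 2, 4, 8, 16  # hypothetical sizes

inputs = tf.placeholder(tf.float32, [None, time_step, dim], name='inputs')
word_in_sen_len = tf.placeholder(tf.int32, [None], name='word_in_sen_len')

outputs, state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    cell_bw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    inputs=inputs,
    sequence_length=word_in_sen_len,
    dtype=tf.float32,
    scope='bilstm_doc_word')
outputs = tf.concat(outputs, 2)  # [batch_size, time_step, 2 * hidden_size]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, feed_dict={
        inputs: np.random.randn(batch_size, time_step, dim),
        word_in_sen_len: [4, 3],  # the second sequence has only 3 valid steps
    })
    print(out.shape)  # (2, 4, 32)

Note that, because sequence_length is passed, the returned outputs beyond each sequence's true length are zero vectors.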

Here is a tutorial on how to use tf.nn.bidirectional_dynamic_rnn().

An Introduction to How TensorFlow Bidirectional Dynamic RNN Process Variable Length Sequence – LSTM Tutorial

How to average the outputs?

To average the outputs, we should notice that not all outputs are valid.

In the example above, suppose \(t = 50\) but the true sequence length is only 3, so only \([h_0, h_1, h_2]\) are valid. We should then average the outputs as:

\(\frac{h_0+h_1+h_2}{3}\)
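In general, if \(m_i\) is 1 for a valid time step and 0 for a padded one, the masked average is:

\(\bar{h} = \frac{\sum_{i=0}^{t} m_i h_i}{\sum_{i=0}^{t} m_i}\)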

This means we should use sequence_length to create a mask that hides the invalid outputs.
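For example, with a hypothetical batch whose true lengths are [3, 5] and time_step = 5, tf.sequence_mask() produces the 0/1 mask that shields the padded steps:

import tensorflow as tf  # TensorFlow 1.x

length = tf.constant([3, 5])  # true length of each sequence (made-up values)
mask = tf.cast(tf.sequence_mask(length, maxlen=5), tf.float32)

with tf.Session() as sess:
    print(sess.run(mask))
    # [[1. 1. 1. 0. 0.]
    #  [1. 1. 1. 1. 1.]]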

We will create a function to implement it.

def avg_lstm(inputs, length):
    # inputs: [batch_size, time_step, dim], length: [batch_size] true sequence lengths
    inputs = tf.cast(inputs, tf.float32)
    time_step = tf.shape(inputs)[1]
    length = tf.reshape(length, [-1])
    # mask: [batch_size, time_step, 1], 1.0 for valid steps and 0.0 for padding
    mask = tf.expand_dims(tf.cast(tf.sequence_mask(length, time_step), tf.float32), -1)
    inputs *= mask  # zero out the invalid time steps
    _sum = tf.reduce_sum(inputs, axis=1)  # [batch_size, dim]
    # divide by the true length of each sequence (1e-9 avoids division by zero)
    length = tf.reshape(tf.cast(length, tf.float32), [-1, 1]) + 1e-9
    return _sum / length

Then, we can get the average output easily.

output = avg_lstm(outputs, word_in_sen_len)
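As a small sanity check (a sketch with made-up data, assuming the avg_lstm() function defined above is in scope), the result can be compared against a plain NumPy average over only the valid steps:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

data = np.random.randn(2, 4, 6).astype(np.float32)  # [batch_size, time_step, dim], made-up
lengths = np.array([4, 2], dtype=np.int32)  # true length of each sequence

avg = avg_lstm(tf.constant(data), tf.constant(lengths))
with tf.Session() as sess:
    tf_avg = sess.run(avg)

# NumPy reference: average only the first `length` steps of each sequence
np_avg = np.stack([data[i, :lengths[i]].mean(axis=0) for i in range(len(lengths))])
print(np.allclose(tf_avg, np_avg, atol=1e-5))  # True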

To learn how to use tf.sequence_mask(), you can read this tutorial:

Understand TensorFlow tf.sequence_mask(): Create a Mask Tensor to Shield Elements – TensorFlow Tutorial
