Understand tf.nn.dynamic_rnn() for TensorFlow Beginners – TensorFlow Tutorial

By | November 11, 2019

TensorFlow's tf.nn.dynamic_rnn() is often used to create LSTM or RNN networks in deep learning. In this tutorial, we will discuss this function for TensorFlow beginners.

Syntax of tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

It creates a recurrent neural network specified by the RNNCell instance cell.

Parameters explained

cell: An instance of RNNCell, for example an LSTMCell object.
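dynamic_rnn does not care what the cell computes internally; it relies on the RNNCell contract: the cell exposes state_size and output_size, and calling it with one time step of input plus the previous state returns (output, new_state). dynamic_rnn then repeats that call over every time step. The class below is a hypothetical NumPy stand-in (TanhCell is not a TensorFlow class) that only illustrates this contract for a plain tanh RNN.

```python
import numpy as np

class TanhCell:
    """Hypothetical minimal cell mirroring the RNNCell contract:
    calling it maps (input_t, state) -> (output, new_state)."""

    def __init__(self, input_depth, num_units, seed=0):
        rng = np.random.default_rng(seed)
        self.w_x = rng.normal(scale=0.1, size=(input_depth, num_units))
        self.w_h = rng.normal(scale=0.1, size=(num_units, num_units))
        self.state_size = num_units
        self.output_size = num_units

    def zero_state(self, batch_size):
        # All-zero state of shape [batch_size, state_size]
        return np.zeros((batch_size, self.state_size))

    def __call__(self, input_t, state):
        # input_t: [batch_size, input_depth], state: [batch_size, state_size]
        new_state = np.tanh(input_t @ self.w_x + state @ self.w_h)
        return new_state, new_state  # for this simple cell, output == new state

cell = TanhCell(input_depth=6, num_units=10)
state = cell.zero_state(batch_size=4)
x_t = np.ones((4, 6))  # a single time step for a batch of 4
output, state = cell(x_t, state)
print(output.shape)  # (4, 10)
```

An LSTMCell follows the same contract, but its state is a (c, h) pair instead of a single array.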

inputs: The RNN inputs. This parameter is very important, so pay attention to its shape, which depends on time_major:

Condition              Inputs shape
time_major == False    [batch_size, max_time, …]
time_major == True     [max_time, batch_size, …]

sequence_length: (optional) An int32/int64 vector of size [batch_size]. It is very important in NLP: when you use an LSTM to process documents, the documents usually have different lengths, so we use this vector to store the actual length of each document. For example:

sequence_length = [23, 23, 12,……, 45]
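To make the effect of sequence_length concrete, here is a minimal NumPy sketch (not TensorFlow code) of a plain tanh RNN that follows the rules dynamic_rnn documents for padded batches: for time steps past a sequence's length the output is zero, and the state of that sequence is copied through unchanged, so the final state is the state at its last valid step. The function masked_rnn, the weights w_x and w_h, and the lengths are hypothetical.

```python
import numpy as np

def masked_rnn(inputs, sequence_length, w_x, w_h):
    """Run a simple tanh RNN over batch-major inputs [batch, max_time, depth].

    Mimics dynamic_rnn's sequence_length handling: outputs past a sequence's
    length are zeros, and the final state is the state at its last valid step.
    """
    batch_size, max_time, _ = inputs.shape
    num_units = w_h.shape[0]
    state = np.zeros((batch_size, num_units))
    outputs = np.zeros((batch_size, max_time, num_units))
    lengths = np.asarray(sequence_length)
    for t in range(max_time):
        # Computed for every sequence; finished ones are masked out below
        new_state = np.tanh(inputs[:, t] @ w_x + state @ w_h)
        # valid[b] is True while step t is inside sequence b
        valid = (t < lengths)[:, None]
        outputs[:, t] = np.where(valid, new_state, 0.0)
        # Finished sequences keep their old state
        state = np.where(valid, new_state, state)
    return outputs, state

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5, 2))  # batch_size=3, max_time=5, depth=2
lengths = [5, 2, 3]             # hypothetical document lengths
w_x = rng.normal(size=(2, 4))   # num_units=4
w_h = rng.normal(size=(4, 4))
outputs, state = masked_rnn(x, lengths, w_x, w_h)
print(outputs[1, 2:])           # all zeros: sequence 1 has length 2
```

This is why padding the documents with zeros is safe: the padded steps do not change the final state of the shorter documents.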

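initial_state: (optional) An initial state for the RNN. We often initialize it to zeros with the cell's zero_state() method. If you do not pass initial_state, you must pass dtype instead.

init_state = cell.zero_state(batch_size, dtype=tf.float32)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)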
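For an LSTMCell created with state_is_tuple=True, zero_state() returns a tuple of two all-zero tensors, the cell state c and the hidden state h, each of shape [batch_size, num_units]; batch_size must match the batch dimension of inputs. The NumPy sketch below only mimics that structure: LSTMState and lstm_zero_state are hypothetical stand-ins, not TensorFlow's LSTMStateTuple.

```python
import numpy as np
from collections import namedtuple

# Hypothetical stand-in for TensorFlow's LSTMStateTuple: a (c, h) pair
LSTMState = namedtuple("LSTMState", ["c", "h"])

def lstm_zero_state(batch_size, num_units, dtype=np.float32):
    """Mimic LSTMCell(num_units, state_is_tuple=True).zero_state(...):
    c and h are both zeros of shape [batch_size, num_units]."""
    zeros = np.zeros((batch_size, num_units), dtype=dtype)
    return LSTMState(c=zeros.copy(), h=zeros.copy())

init_state = lstm_zero_state(batch_size=4, num_units=10)
print(init_state.c.shape, init_state.h.shape)  # (4, 10) (4, 10)
```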

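time_major: controls the layout of the inputs and outputs tensors. If True, they are [max_time, batch_size, depth]; if False (the default), they are [batch_size, max_time, depth]. time_major = True is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation, but most TensorFlow data is batch-major, so the default is False.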
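If your data is batch-major but you want time_major = True (or the other way round), you only need to swap the first two axes. The sketch below shows this in NumPy; in TensorFlow the equivalent call is tf.transpose(x, [1, 0, 2]).

```python
import numpy as np

batch_size, max_time, depth = 4, 3, 6
# Batch-major tensor: [batch_size, max_time, depth]
x_batch_major = np.arange(batch_size * max_time * depth, dtype=np.float32).reshape(batch_size, max_time, depth)

# Swap the first two axes: [batch_size, max_time, depth] -> [max_time, batch_size, depth]
x_time_major = np.transpose(x_batch_major, (1, 0, 2))
print(x_time_major.shape)  # (3, 4, 6)

# Element (batch b, time t) sits at [t, b] in the time-major copy
assert x_time_major[2, 1, 5] == x_batch_major[1, 2, 5]
```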

Output of tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn() returns a pair (outputs, state): outputs contains the RNN output at every time step, and state is the final state.

The shape of outputs is controlled by time_major:

Condition              Outputs shape
time_major == False    [batch_size, max_time, cell.output_size]
time_major == True     [max_time, batch_size, cell.output_size]

Here is an example of using it to create an LSTM network:

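import tensorflow as tf  # TensorFlow 1.x API

batch_size = 4
# time_major=True, so inputs are [max_time, batch_size, depth] = [3, 4, 6]
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
# LSTM cell with 10 units; its state is a (c, h) tuple
cell = tf.nn.rnn_cell.LSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=True)
print(output.shape)  # static shape: [max_time, batch_size, num_units]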

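In the code above, time_major = True, batch_size = 4, max_time = 3, the input depth is 6, and the LSTM cell has num_units = 10.

The shape of the LSTM inputs is [3, 4, 6], i.e. [max_time, batch_size, depth].

The shape of the LSTM outputs is [3, 4, 10], i.e. [max_time, batch_size, cell.output_size]. Because state_is_tuple=True, final_state is a (c, h) tuple, and both c and h have shape [4, 10].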
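To see where [3, 4, 10] comes from, here is a minimal NumPy sketch of a time-major LSTM loop. It assumes the gate layout TensorFlow's LSTMCell uses internally (the input is concatenated with the previous h, and the result is split into input, new-input, forget and output gates, in that order), and it ignores peepholes and projections. The weights are random, so it reproduces the shapes, not TensorFlow's numbers; lstm_time_major is a hypothetical helper.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_time_major(inputs, num_units, forget_bias=1.0, seed=0):
    """Sketch of a time-major LSTM: inputs [max_time, batch_size, depth]
    -> outputs [max_time, batch_size, num_units] plus the final (c, h)."""
    max_time, batch_size, depth = inputs.shape
    rng = np.random.default_rng(seed)
    # One kernel for all four gates, split as i, j, f, o (assumed TF order)
    kernel = rng.normal(scale=0.1, size=(depth + num_units, 4 * num_units))
    bias = np.zeros(4 * num_units)
    c = np.zeros((batch_size, num_units))  # zero initial state, like zero_state()
    h = np.zeros((batch_size, num_units))
    outputs = []
    for t in range(max_time):
        gates = np.concatenate([inputs[t], h], axis=1) @ kernel + bias
        i, j, f, o = np.split(gates, 4, axis=1)
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)  # the output at each step is h
    return np.stack(outputs), (c, h)

x = np.random.default_rng(1).normal(size=(3, 4, 6))  # [max_time=3, batch_size=4, depth=6]
outputs, (c, h) = lstm_time_major(x, num_units=10)
print(outputs.shape)  # (3, 4, 10)
```

Since this sketch has no projection, outputs[-1] equals the final h. In the same situation (LSTMCell without projection and without sequence_length), final_state.h from dynamic_rnn likewise matches output[-1].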
