TensorFlow's tf.nn.dynamic_rnn() is often used to build an LSTM or RNN network in deep learning. In this tutorial, we will discuss this function for TensorFlow beginners.
Syntax of tf.nn.dynamic_rnn()
```python
tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
```
It creates a recurrent neural network specified by the RNNCell cell.
Parameters explained
cell: An instance of RNNCell, for example an LSTMCell object.
inputs: The RNN inputs. This parameter is very important; we should pay close attention to its shape.
| Condition | Inputs shape |
| --- | --- |
| time_major == False | [batch_size, max_time, …] |
| time_major == True | [max_time, batch_size, …] |
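The two layouts differ only in the order of the first two axes. A minimal NumPy sketch of the relationship between them (the dimension sizes below are invented for illustration):

```python
import numpy as np

batch_size, max_time, input_dim = 4, 3, 6

# Batch-major layout (time_major == False): [batch_size, max_time, ...]
batch_major = np.random.randn(batch_size, max_time, input_dim)

# Time-major layout (time_major == True) is just a transpose
# of the first two axes: [max_time, batch_size, ...]
time_major = np.transpose(batch_major, (1, 0, 2))

print(batch_major.shape)  # (4, 3, 6)
print(time_major.shape)   # (3, 4, 6)
```

In TensorFlow the same swap can be done with tf.transpose before feeding the tensor to tf.nn.dynamic_rnn().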
sequence_length: (optional) An int32/int64 vector sized [batch_size]. It is very important in NLP: when you use an LSTM to process documents, the documents do not all have the same length, so we should use this parameter to record the length of each document. For example:
sequence_length = [23, 23, 12,……, 45]
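If your documents are padded with zeros up to a common max_time, one common way to build this vector is to count the non-padding steps per row. A NumPy sketch, assuming 0 is the padding id (the token ids below are invented for illustration):

```python
import numpy as np

# Padded token ids, shape [batch_size, max_time]; 0 is assumed to be the padding id.
padded = np.array([
    [5, 2, 9, 0, 0],   # real length 3
    [1, 4, 0, 0, 0],   # real length 2
    [7, 3, 8, 6, 2],   # real length 5
])

# sequence_length: number of non-padding steps in each document
sequence_length = np.count_nonzero(padded, axis=1)
print(sequence_length)  # [3 2 5]
```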
initial_state: (optional) An initial state for the RNN. We often initialize it with:
```python
init_state = cell.zero_state(batch_size, dtype=tf.float32)
```
and then pass initial_state=init_state to tf.nn.dynamic_rnn().
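For an LSTMCell with state_is_tuple=True, zero_state() returns an LSTMStateTuple (c, h), where c and h each have shape [batch_size, num_units]. A NumPy sketch of the shapes only (the real function returns TensorFlow tensors, not arrays):

```python
import numpy as np

batch_size, num_units = 4, 10

# Shape-wise, cell.zero_state(batch_size, tf.float32) for an LSTM corresponds to:
c = np.zeros((batch_size, num_units), dtype=np.float32)  # cell state
h = np.zeros((batch_size, num_units), dtype=np.float32)  # hidden state
init_state = (c, h)  # mirrors tf.nn.rnn_cell.LSTMStateTuple(c, h)

print(c.shape, h.shape)  # (4, 10) (4, 10)
```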
time_major: controls the shape format of the inputs and outputs.
Output of tf.nn.dynamic_rnn()
tf.nn.dynamic_rnn() returns a pair (outputs, state).
The shape of outputs is controlled by time_major.
| Condition | Outputs shape |
| --- | --- |
| time_major == False | [batch_size, max_time, cell.output_size] |
| time_major == True | [max_time, batch_size, cell.output_size] |
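When sequence_length is passed, dynamic_rnn zero-pads the outputs past each sequence's real length, so a common follow-up step is to gather the last valid output of each sequence from the batch-major outputs. A NumPy sketch of that gathering step, with invented values:

```python
import numpy as np

batch_size, max_time, output_size = 3, 5, 2

# Stand-in for outputs with shape [batch_size, max_time, cell.output_size]
outputs = np.arange(batch_size * max_time * output_size, dtype=np.float32)
outputs = outputs.reshape(batch_size, max_time, output_size)
sequence_length = np.array([3, 2, 5])

# Last valid output of each sequence: outputs[i, sequence_length[i] - 1]
last_outputs = outputs[np.arange(batch_size), sequence_length - 1]
print(last_outputs.shape)  # (3, 2)
```

The same indexing can be done in TensorFlow with tf.gather_nd.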
Here is an example of using it to create an LSTM network:
```python
import tensorflow as tf  # TensorFlow 1.x API

batch_size = 4
# inputs in time-major layout: [max_time, batch_size, input_dim] = [3, 4, 6]
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.LSTMCell(10, forget_bias=1.0, state_is_tuple=True)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=True)
```
In the code above, time_major = True, batch_size = 4, the number of LSTM cell units is 10, and max_time = 3.
The shape of the LSTM inputs is [3, 4, 6].
The shape of the LSTM outputs is [3, 4, 10].