Zoneout is proposed in the paper: Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations. It is also used in Tacotron 2. In this tutorial, we will introduce what zoneout is and how to implement it using TensorFlow.
What is zoneout?
Like dropout, zoneout is applied at each timestep of an LSTM, but instead of setting units to zero it forces some hidden units to keep their values from the previous timestep.
An LSTM is defined as follows:
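Using the notation of the zoneout paper, the LSTM update at timestep \(t\) is:

\[
\begin{aligned}
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \\
g_t &= \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
\]

where \(i_t\), \(f_t\) and \(o_t\) are the input, forget and output gates, \(g_t\) is the candidate cell update, \(c_t\) is the cell state and \(h_t\) is the hidden output.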
To learn more about LSTM, you can read:
Understand Long Short-Term Memory Network(LSTM) – LSTM Tutorial
In order to regularize an LSTM, we can apply dropout at each timestep. For example:
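One common option, following the recurrent dropout formulation discussed in the zoneout paper, is to drop the cell update:

\[
c_t = f_t \odot c_{t-1} + i_t \odot d_t \odot g_t
\]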
Here \(d_t\) is a zero-mask.
However, zoneout uses a different method.
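In the zoneout paper, the cell state and hidden output are updated as:

\[
\begin{aligned}
c_t &= d^c_t \odot c_{t-1} + (1 - d^c_t) \odot \big(f_t \odot c_{t-1} + i_t \odot g_t\big) \\
h_t &= d^h_t \odot h_{t-1} + (1 - d^h_t) \odot \big(o_t \odot \tanh(f_t \odot c_{t-1} + i_t \odot g_t)\big)
\end{aligned}
\]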
Here \(d^c_t\) and \(d^h_t\) are zoneout masks; they act as gates that control how much of the previous cell state and hidden output is carried over to the current timestep.
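To make this gating concrete, here is a toy NumPy illustration (not from the paper or the code below; the numbers are made up) of applying a zoneout mask to a single hidden vector:

import numpy as np

np.random.seed(0)

prev_h = np.array([0.5, -0.2, 0.8, 0.1])  # hidden output at timestep t-1
new_h = np.array([0.9, 0.3, -0.4, 0.6])   # proposed hidden output at timestep t
zoneout_prob = 0.5                        # probability of keeping the previous value

# d = 1 means "keep the previous value", d = 0 means "take the new value"
d = np.random.binomial(1, zoneout_prob, size=prev_h.shape)

h = d * prev_h + (1 - d) * new_h
print("mask:", d)
print("zoned-out h:", h)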
Here is a variant of zoneout.
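Based on the TensorFlow implementation shown below (and the reference implementation linked in the paper), the training-time update can be written as:

\[
\begin{aligned}
c_t &= (1 - z_c)\,\mathrm{dropout}\big(\tilde{c}_t - c_{t-1},\; 1 - z_c\big) + c_{t-1} \\
h_t &= (1 - z_h)\,\mathrm{dropout}\big(\tilde{h}_t - h_{t-1},\; 1 - z_h\big) + h_{t-1}
\end{aligned}
\]

where \(\tilde{c}_t\) and \(\tilde{h}_t\) are the new cell state and hidden output produced by the vanilla LSTM, \(z_c\) and \(z_h\) are the zoneout factors, and the second argument of dropout is the keep probability. At inference time the random mask is replaced by its expectation, i.e. a deterministic weighted average \((1 - z)\,\tilde{s}_t + z\,s_{t-1}\).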
How to implement zoneout in TensorFlow?
Here is an example:
import tensorflow as tf
import numpy as np


class ZoneoutLSTMCell(tf.nn.rnn_cell.RNNCell):
    '''Wrapper for tf LSTM to create Zoneout LSTM Cell

    inspired by:
    https://github.com/teganmaharaj/zoneout/blob/master/zoneout_tensorflow.py

    Published by one of 'https://arxiv.org/pdf/1606.01305.pdf' paper writers.

    Many thanks to @Ondal90 for pointing this out. You sir are a hero!
    '''
    def __init__(self, num_units, is_training, zoneout_factor_cell=0.,
                 zoneout_factor_output=0., state_is_tuple=True, name=None):
        '''Initializer with possibility to set different zoneout values for cell/hidden states.
        '''
        zm = min(zoneout_factor_output, zoneout_factor_cell)
        zs = max(zoneout_factor_output, zoneout_factor_cell)

        if zm < 0. or zs > 1.:
            raise ValueError('One/both provided Zoneout factors are not in [0, 1]')

        self._cell = tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=state_is_tuple, name=name)
        self._zoneout_cell = zoneout_factor_cell
        self._zoneout_outputs = zoneout_factor_output
        self.is_training = is_training
        self.state_is_tuple = state_is_tuple

    @property
    def state_size(self):
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def zoneout(self, new_c, prev_c, new_h, prev_h):
        # Training branch: randomly drop part of the state update, then add back the previous state.
        c = (1 - self._zoneout_cell) * tf.nn.dropout(new_c - prev_c, (1 - self._zoneout_cell)) + prev_c
        h = (1 - self._zoneout_outputs) * tf.nn.dropout(new_h - prev_h, (1 - self._zoneout_outputs)) + prev_h
        return c, h

    def normout(self, new_c, prev_c, new_h, prev_h):
        # Inference branch: deterministic weighted average of new and previous states.
        c = (1 - self._zoneout_cell) * new_c + self._zoneout_cell * prev_c
        h = (1 - self._zoneout_outputs) * new_h + self._zoneout_outputs * prev_h
        return c, h

    def __call__(self, inputs, state, scope=None):
        '''Runs vanilla LSTM Cell and applies zoneout.
        '''
        # Apply vanilla LSTM
        output, new_state = self._cell(inputs, state, scope)

        if self.state_is_tuple:
            (prev_c, prev_h) = state
            (new_c, new_h) = new_state
        else:
            num_proj = self._cell._num_units if self._cell._num_proj is None else self._cell._num_proj
            prev_c = tf.slice(state, [0, 0], [-1, self._cell._num_units])
            prev_h = tf.slice(state, [0, self._cell._num_units], [-1, num_proj])
            new_c = tf.slice(new_state, [0, 0], [-1, self._cell._num_units])
            new_h = tf.slice(new_state, [0, self._cell._num_units], [-1, num_proj])

        # Apply zoneout
        c, h = tf.cond(self.is_training,
                       lambda: self.zoneout(new_c, prev_c, new_h, prev_h),
                       lambda: self.normout(new_c, prev_c, new_h, prev_h))

        new_state = tf.nn.rnn_cell.LSTMStateTuple(c, h) if self.state_is_tuple else tf.concat([c, h], axis=1)

        return output, new_state
Then we can start to evaluate it.
size = 128
zoneout = 0.1

inputs = tf.Variable(tf.truncated_normal([3, 20, 100], stddev=0.1), name="inputs")
input_lengths = tf.Variable(tf.truncated_normal([3, 20], stddev=0.1), name="input_lengths")  # not used below, since sequence_length=None
is_training = tf.Variable(False, dtype=tf.bool)

_fw_cell = ZoneoutLSTMCell(size, is_training,
                           zoneout_factor_cell=zoneout,
                           zoneout_factor_output=zoneout,
                           name='encoder_fw_LSTM')
_bw_cell = ZoneoutLSTMCell(size, is_training,
                           zoneout_factor_cell=zoneout,
                           zoneout_factor_output=zoneout,
                           name='encoder_bw_LSTM')

with tf.variable_scope("Zoneout_BiLSTM"):
    outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
        _fw_cell,
        _bw_cell,
        inputs,
        sequence_length=None,
        dtype=tf.float32,
        swap_memory=True)

    outputs = tf.concat(outputs, axis=2)  # Concat and return forward + backward outputs

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    f = sess.run([inputs, outputs])
    print("f shape=", f[0].shape, f[1].shape)
    print(f)
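Note that is_training is a constant False variable here, so only the inference branch (normout) is executed. If you want to enable zoneout during training and disable it at evaluation with the same graph, a minimal sketch (assuming TensorFlow 1.x and continuing from the code above) is to feed a boolean placeholder instead:

# A sketch: use a placeholder so zoneout can be switched on/off per sess.run() call.
is_training = tf.placeholder_with_default(False, shape=(), name="is_training")

cell = ZoneoutLSTMCell(size, is_training,
                       zoneout_factor_cell=zoneout,
                       zoneout_factor_output=zoneout,
                       name='zoneout_lstm')

# ... build the rest of the graph with this cell ...

# During training:
#   sess.run(train_op, feed_dict={is_training: True})
# During evaluation (zoneout becomes a deterministic weighted average):
#   sess.run(outputs, feed_dict={is_training: False})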
Running this code, we will get:
f shape= (3, 20, 100) (3, 20, 256)
[array([[[ 0.0553, -0.1388,  0.0857, ..., -0.098 , -0.0475,  0.0474],
        [ 0.0308,  0.1614,  0.1135, ..., -0.187 ,  0.0044,  0.1501],
        [-0.028 , -0.1262, -0.0348, ...,  0.0382,  0.053 , -0.1148],
        ...,
        [-0.0047,  0.0642, -0.191 , ..., -0.0884, -0.0713,  0.0217],
        [-0.0238, -0.0397, -0.0248, ..., -0.0297, -0.0909, -0.0292],
        [ 0.061 ,  0.0694, -0.0787, ..., -0.0181, -0.0957,  0.1915]]],
      dtype=float32), array([[[ 0.021 , -0.0018,  0.0255, ..., -0.0204,  0.0113, -0.0092],
        [ 0.0059, -0.0008,  0.0118, ..., -0.0149,  0.0119,  0.0059],
        [ 0.0041,  0.0112, -0.0074, ..., -0.0207,  0.0161, -0.0113],
        ...,
        [-0.0034,  0.013 ,  0.0065, ..., -0.0145,  0.0198,  0.0044],
        [ 0.0015,  0.0066,  0.0003, ..., -0.0148,  0.0285, -0.0098],
        [ 0.0115,  0.0081,  0.0109, ..., -0.0101,  0.012 , -0.0094]]],
      dtype=float32)]
In this example, the input has shape 3×20×100 and the output has shape 3×20×256: the forward and backward LSTM outputs (128 units each) are concatenated along the last dimension.