Regularizing LSTM at Each Timestep with Zoneout – Deep Learning Tutorial

March 30, 2022

Zoneout was proposed in the paper Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations and is also used in Tacotron 2. In this tutorial, we will introduce what it is and how to implement it using TensorFlow.

What is zoneout?

Like dropout, zoneout regularizes an LSTM at each timestep. However, instead of setting some hidden units to zero, it forces them to keep their values from the previous timestep.

An LSTM is defined as follows:

\( i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \)

\( f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \)

\( o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \)

\( g_t = \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g) \)

\( c_t = f_t \odot c_{t-1} + i_t \odot g_t \)

\( h_t = o_t \odot \tanh(c_t) \)

Here \(\sigma\) is the sigmoid function and \(\odot\) denotes element-wise multiplication.
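To make these equations concrete, here is a minimal NumPy sketch of one LSTM timestep. The function name lstm_step and the stacked weight layout are our own assumptions for illustration, not part of any library.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM timestep. W: (4*n, input_dim), U: (4*n, n), b: (4*n,)."""
    n = h_prev.shape[0]
    z = W @ x_t + U @ h_prev + b        # pre-activations of all four gates at once
    i = sigmoid(z[0*n:1*n])             # input gate
    f = sigmoid(z[1*n:2*n])             # forget gate
    o = sigmoid(z[2*n:3*n])             # output gate
    g = np.tanh(z[3*n:4*n])             # candidate cell update
    c_t = f * c_prev + i * g            # new cell state
    h_t = o * np.tanh(c_t)              # new hidden output
    return h_t, c_t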

To learn more about LSTM, you can read:

Understand Long Short-Term Memory Network(LSTM) – LSTM Tutorial

In order to regularize an LSTM, we can apply dropout at each timestep. For example:

\( c_t = f_t \odot c_{t-1} + i_t \odot (d_t \odot g_t) \)

Here \(d_t\) is a zero-mask: a binary dropout mask that is resampled at every timestep, so dropped units contribute nothing to the cell update.
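For clarity, here is a minimal NumPy sketch of this per-timestep dropout applied to the cell update, assuming the gate activations \(f_t\), \(i_t\), \(g_t\) have already been computed; the function name dropout_cell_update is our own.

import numpy as np

def dropout_cell_update(f, i, g, c_prev, drop_rate=0.1):
    """Per-timestep dropout on the candidate update: dropped units add nothing to c_t."""
    d_t = (np.random.rand(*g.shape) >= drop_rate).astype(g.dtype)  # zero-mask, resampled each step
    return f * c_prev + i * (d_t * g)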

However, zoneout uses a different method.

\( c_t = d^c_t \odot c_{t-1} + (1 - d^c_t) \odot (f_t \odot c_{t-1} + i_t \odot g_t) \)

\( h_t = d^h_t \odot h_{t-1} + (1 - d^h_t) \odot (o_t \odot \tanh(f_t \odot c_{t-1} + i_t \odot g_t)) \)

Here \(d^c_t\) and \(d^h_t\) are the zoneout masks. They act as gates that control how much of the previous cell state and hidden output is carried over to the current timestep.
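For illustration, here is a minimal NumPy sketch of the zoneout update at a single timestep, assuming the LSTM has already produced the new cell state and hidden output; the function name zoneout_step and the rate names are our own assumptions.

import numpy as np

def zoneout_step(new_c, prev_c, new_h, prev_h, z_cell=0.1, z_out=0.1):
    """With probability z, a unit keeps its value from the previous timestep."""
    d_c = (np.random.rand(*prev_c.shape) < z_cell).astype(prev_c.dtype)  # 1 = keep previous cell value
    d_h = (np.random.rand(*prev_h.shape) < z_out).astype(prev_h.dtype)   # 1 = keep previous hidden value
    c_t = d_c * prev_c + (1.0 - d_c) * new_c
    h_t = d_h * prev_h + (1.0 - d_h) * new_h
    return c_t, h_t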

Here is a variant of zoneout.

During training:

\( c_t = (1 - z_c) \cdot \text{dropout}(\tilde{c}_t - c_{t-1}, 1 - z_c) + c_{t-1} \)

\( h_t = (1 - z_h) \cdot \text{dropout}(\tilde{h}_t - h_{t-1}, 1 - z_h) + h_{t-1} \)

At inference time:

\( c_t = (1 - z_c) \cdot \tilde{c}_t + z_c \cdot c_{t-1} \)

\( h_t = (1 - z_h) \cdot \tilde{h}_t + z_h \cdot h_{t-1} \)

Here \(\tilde{c}_t\) and \(\tilde{h}_t\) are the new cell state and hidden output produced by the LSTM, and \(z_c\), \(z_h\) are the zoneout rates. Because \(\text{dropout}(x, 1 - z)\) keeps each element with probability \(1 - z\) and scales kept elements by \(1 / (1 - z)\), the \((1 - z)\) prefactor cancels the scaling, so during training the dropped units keep their previous value exactly. This is the form implemented in the TensorFlow code below.
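To see how this variant behaves, here is a minimal NumPy sketch (the function name and defaults are our assumptions): during training, a binary mask is applied to the state change, so masked units keep their previous value exactly; at inference time, a deterministic interpolation is used instead.

import numpy as np

def zoneout_variant(new_c, prev_c, z_cell=0.1, is_training=True):
    """Zoneout variant: mask the state change in training, interpolate at inference."""
    if is_training:
        # 1 = accept the new value, 0 = keep the previous value
        mask = (np.random.rand(*prev_c.shape) >= z_cell).astype(prev_c.dtype)
        return mask * (new_c - prev_c) + prev_c
    # inference: use the expected value of the stochastic update
    return (1.0 - z_cell) * new_c + z_cell * prev_c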

How to implement zoneout in TensorFlow?

Here is an example:

import tensorflow as tf
import numpy as np

class ZoneoutLSTMCell(tf.nn.rnn_cell.RNNCell):
	'''Wrapper for tf LSTM to create Zoneout LSTM Cell

	inspired by:
	https://github.com/teganmaharaj/zoneout/blob/master/zoneout_tensorflow.py

	Published by one of 'https://arxiv.org/pdf/1606.01305.pdf' paper writers.

	Many thanks to @Ondal90 for pointing this out. You sir are a hero!
	'''
	def __init__(self, num_units, is_training, zoneout_factor_cell=0., zoneout_factor_output=0., state_is_tuple=True, name=None):
		'''Initializer with possibility to set different zoneout values for cell/hidden states.
		'''
		zm = min(zoneout_factor_output, zoneout_factor_cell)
		zs = max(zoneout_factor_output, zoneout_factor_cell)

		if zm < 0. or zs > 1.:
			raise ValueError('One/both provided Zoneout factors are not in [0, 1]')

		self._cell = tf.nn.rnn_cell.LSTMCell(num_units, state_is_tuple=state_is_tuple, name=name)
		self._zoneout_cell = zoneout_factor_cell
		self._zoneout_outputs = zoneout_factor_output
		self.is_training = is_training
		self.state_is_tuple = state_is_tuple

	@property
	def state_size(self):
		return self._cell.state_size

	@property
	def output_size(self):
		return self._cell.output_size

	def zoneout(self, new_c, prev_c, new_h, prev_h):
		'''Training-time update: tf.nn.dropout keeps each element of (new - prev) with
		probability (1 - zoneout_rate) and scales it by 1 / (1 - zoneout_rate), so after
		multiplying by (1 - zoneout_rate) the kept units take their new value and the
		dropped units keep their previous value exactly.'''
		c = (1 - self._zoneout_cell) * tf.nn.dropout(new_c - prev_c, (1 - self._zoneout_cell)) + prev_c
		h = (1 - self._zoneout_outputs) * tf.nn.dropout(new_h - prev_h, (1 - self._zoneout_outputs)) + prev_h
		return c, h

	def normout(self, new_c, prev_c, new_h, prev_h):
		'''Inference-time update: deterministic interpolation with the zoneout rates.'''
		c = (1 - self._zoneout_cell) * new_c + self._zoneout_cell * prev_c
		h = (1 - self._zoneout_outputs) * new_h + self._zoneout_outputs * prev_h
		return c, h

	def __call__(self, inputs, state, scope=None):
		'''Runs vanilla LSTM Cell and applies zoneout.
		'''
		#Apply vanilla LSTM
		output, new_state = self._cell(inputs, state, scope)

		if self.state_is_tuple:
			(prev_c, prev_h) = state
			(new_c, new_h) = new_state
		else:
			num_proj = self._cell._num_units if self._cell._num_proj is None else self._cell._num_proj
			prev_c = tf.slice(state, [0, 0], [-1, self._cell._num_units])
			prev_h = tf.slice(state, [0, self._cell._num_units], [-1, num_proj])
			new_c = tf.slice(new_state, [0, 0], [-1, self._cell._num_units])
			new_h = tf.slice(new_state, [0, self._cell._num_units], [-1, num_proj])

		# Apply zoneout
		c, h = tf.cond(self.is_training, lambda: self.zoneout(new_c, prev_c, new_h, prev_h), lambda: self.normout(new_c, prev_c, new_h, prev_h))

		new_state = tf.nn.rnn_cell.LSTMStateTuple(c, h) if self.state_is_tuple else tf.concat([c, h], axis=1)

		return output, new_state

Then we can start to evaluate it.

size = 128
zoneout = 0.1
inputs = tf.Variable(tf.truncated_normal([3, 20, 100], stddev=0.1), name="inputs")
is_training = tf.Variable(False, dtype=tf.bool)
_fw_cell = ZoneoutLSTMCell(size, is_training,
                           zoneout_factor_cell=zoneout,
                           zoneout_factor_output=zoneout,
                           name='encoder_fw_LSTM')
_bw_cell = ZoneoutLSTMCell(size, is_training,
                           zoneout_factor_cell=zoneout,
                           zoneout_factor_output=zoneout,
                           name='encoder_bw_LSTM')
with tf.variable_scope("Zoneout_BiLSTM"):
    outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
        _fw_cell,
        _bw_cell,
        inputs,
        sequence_length=None,
        dtype=tf.float32,
        swap_memory=True)

    outputs = tf.concat(outputs, axis=2)  # Concat and return forward + backward outputs

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    f = sess.run([inputs, outputs])

    print("f shape=", f[0].shape, f[1].shape)
    print(f)

Running this code, we will get:

f shape= (3, 20, 100) (3, 20, 256)
[array([[[ 0.0553, -0.1388,  0.0857, ..., -0.098 , -0.0475,  0.0474],
        [ 0.0308,  0.1614,  0.1135, ..., -0.187 ,  0.0044,  0.1501],
        [-0.028 , -0.1262, -0.0348, ...,  0.0382,  0.053 , -0.1148],
        ...,
        [-0.0047,  0.0642, -0.191 , ..., -0.0884, -0.0713,  0.0217],
        [-0.0238, -0.0397, -0.0248, ..., -0.0297, -0.0909, -0.0292],
        [ 0.061 ,  0.0694, -0.0787, ..., -0.0181, -0.0957,  0.1915]]],
      dtype=float32), array([[[ 0.021 , -0.0018,  0.0255, ..., -0.0204,  0.0113, -0.0092],
        [ 0.0059, -0.0008,  0.0118, ..., -0.0149,  0.0119,  0.0059],
        [ 0.0041,  0.0112, -0.0074, ..., -0.0207,  0.0161, -0.0113],
        ...,
        [-0.0034,  0.013 ,  0.0065, ..., -0.0145,  0.0198,  0.0044],
        [ 0.0015,  0.0066,  0.0003, ..., -0.0148,  0.0285, -0.0098],
        [ 0.0115,  0.0081,  0.0109, ..., -0.0101,  0.012 , -0.0094]]],
      dtype=float32)]

Process finished with exit code 0

In this example, the input has shape 3×20×100 and the output has shape 3×20×256: the forward and backward LSTMs each produce 128 units, which are concatenated along the last axis.
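In the example above, is_training is a constant False variable, so zoneout is never applied. One possible way to toggle it at run time is a scalar boolean placeholder; the sketch below assumes a placeholder named is_training, which is not part of the original code.

# A minimal sketch: feed a scalar bool to switch zoneout on for training steps.
is_training = tf.placeholder_with_default(False, shape=[], name="is_training")

# Build the ZoneoutLSTMCell / bidirectional_dynamic_rnn graph with this tensor as before.

# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     train_out = sess.run(outputs, feed_dict={is_training: True})  # zoneout active
#     eval_out = sess.run(outputs)                                   # zoneout disabled (default False)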
