We can inherit TensorFlow's RNNCell to create a custom RNN cell and create some variables in its build() function. However, we may get this error:

ValueError: Initializer for variable is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.

In this tutorial, we will show you how to fix it.
For example:
import tensorflow as tf

class CustomCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, num_units, reuse=None, dtype=tf.float32, name="custom_cell"):
        super(CustomCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
        self._num_units = num_units  # the dimension of the rnn cell

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def build(self, inputs_shape):
        inputs_dim = inputs_shape[-1].value
        # Creating variables with tf.Variable() here will trigger the ValueError.
        self._w = tf.Variable(tf.truncated_normal([inputs_dim, self._num_units], stddev=0.1), name="weight")
        self._b = tf.Variable(tf.truncated_normal([self._num_units], stddev=0.1), name="bias")
        self.built = True

    def call(self, inputs, state):
        # Use the input and the previous state to compute the new output and state.
        new_h = tf.tanh(tf.matmul(inputs + state, self._w) + self._b)
        new_c = tf.nn.swish(tf.matmul(state, self._w) + self._b)
        return new_h, new_c
If you use this CustomCell class with tf.nn.dynamic_rnn() or tf.nn.bidirectional_dynamic_rnn(), you will get this error:

ValueError: Initializer for variable is from inside a control-flow construct, such as a loop or conditional. When creating a variable inside a loop or conditional, use a lambda as the initializer.
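For instance, a minimal driver like the one below (hypothetical code, not from the original example) reproduces it. The error is raised because dynamic_rnn() calls the cell's build() inside its internal tf.while_loop:

cell = CustomCell(100)
inputs = tf.zeros([3, 20, 100])  # [batch, time, input_dim]
# build() runs inside dynamic_rnn's tf.while_loop, so the initializer tensor
# passed to tf.Variable() is created inside a control-flow construct.
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)  # raises ValueError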
How to fix this ValueError?
We can not create TensorFlow variables in the build() function using tf.Variable(); we should use self.add_variable() instead. The reason is that tf.Variable() receives a ready-made initializer tensor, which here is created inside dynamic_rnn's while_loop, whereas self.add_variable() takes an initializer object, letting TensorFlow create the initial value outside the control-flow construct.
For example:
class CustomCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self, num_units, reuse=None, dtype=tf.float32, name="custom_cell"):
        super(CustomCell, self).__init__(_reuse=reuse, dtype=dtype, name=name)
        self._num_units = num_units  # the dimension of the rnn cell

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def build(self, inputs_shape):
        inputs_dim = inputs_shape[-1].value
        # self.add_variable() creates the variables through the layer's
        # variable scope, which avoids the control-flow problem.
        self._w = self.add_variable(name="weight",
                                    shape=[inputs_dim, self._num_units],
                                    initializer=tf.glorot_normal_initializer(),
                                    dtype=tf.float32)
        self._b = self.add_variable(name="bias",
                                    shape=[self._num_units],
                                    initializer=tf.glorot_normal_initializer(),
                                    dtype=tf.float32)
        self.built = True

    def call(self, inputs, state):
        # Use the input and the previous state to compute the new output and state.
        new_h = tf.tanh(tf.matmul(inputs + state, self._w) + self._b)
        new_c = tf.nn.swish(tf.matmul(state, self._w) + self._b)
        return new_h, new_c
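Alternatively, as the error message itself suggests, you can keep tf.Variable() and pass a zero-argument lambda as the initial value, so the initializer is created lazily rather than inside the loop. This is only a sketch of the hint in the error message; self.add_variable() is still the cleaner fix:

def build(self, inputs_shape):
    inputs_dim = inputs_shape[-1].value
    # A lambda defers creating the initial value tensor.
    self._w = tf.Variable(
        lambda: tf.truncated_normal([inputs_dim, self._num_units], stddev=0.1),
        name="weight")
    self._b = tf.Variable(
        lambda: tf.truncated_normal([self._num_units], stddev=0.1),
        name="bias")
    self.built = True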
Then we can use the code below to evaluate the fixed CustomCell.
import numpy as np

size = 100
inputs = tf.Variable(tf.truncated_normal([3, 20, 100], stddev=0.1), name="inputs")
input_lengths = tf.Variable(tf.truncated_normal([3, 20], stddev=0.1), name="inputs_length")  # unused: sequence_length is None below

_fw_cell = CustomCell(size, name='encoder_fw_')
_bw_cell = CustomCell(size, name='encoder_bw')

with tf.variable_scope("Custom_BiLSTM"):
    outputs, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
        _fw_cell,
        _bw_cell,
        inputs,
        sequence_length=None,
        dtype=tf.float32,
        swap_memory=True)
    outputs = tf.concat(outputs, axis=2)  # concat and return forward + backward outputs

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()

with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    f = sess.run([inputs, outputs])
    print("f shape=", f[0].shape, f[1].shape)
    print(f)
Running this code, we will get the input shape (3, 20, 100), the concatenated bidirectional output shape (3, 20, 200), and their values:
f shape= (3, 20, 100) (3, 20, 200)
[array([[[ 0.1409,  0.107 ,  0.0258, ...,  0.0281, -0.0612,  0.0525],
        [ 0.0652, -0.0202,  0.0169, ..., -0.1956,  0.0543, -0.0334],
        [-0.0559,  0.1613,  0.0257, ...,  0.0858,  0.1105, -0.0963],
        ...,
        [-0.0716, -0.0563, -0.0451, ...,  0.1238, -0.0111, -0.0465],
        [-0.0484,  0.0344, -0.0566, ..., -0.1707, -0.0705,  0.01  ],
        [-0.0037, -0.0209, -0.0565, ...,  0.0233,  0.0548,  0.1174]]], dtype=float32),
 array([[[-0.0311, -0.246 , -0.0412, ..., -0.0077, -0.0249,  0.0986],
        [-0.1012, -0.1396, -0.0219, ..., -0.0893,  0.128 ,  0.1197],
        [ 0.0678, -0.2054, -0.0608, ..., -0.0068, -0.0627,  0.1589],
        ...,
        [ 0.0128, -0.1721,  0.0594, ...,  0.1566,  0.1048, -0.104 ],
        [-0.0187, -0.2965,  0.0667, ..., -0.0013,  0.0132,  0.046 ],
        [ 0.0082, -0.1131,  0.0417, ...,  0.068 ,  0.2191,  0.0546]]], dtype=float32)]
This ValueError is fixed.
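As an optional sanity check (this snippet is not in the original code), you can list the global variables to confirm that the weights and biases were created by self.add_variable() under each cell's scope:

# List all variables in the graph, e.g. the custom cells' weight and bias.
for v in tf.global_variables():
    print(v.name, v.shape)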