In TensorFlow, we can create a custom RNN cell by inheriting from tf.nn.rnn_cell.RNNCell or tensorflow.contrib.rnn.RNNCell. However, you may notice that some RNN cell classes implement __call__() while others implement call(). Why? In this tutorial, we will explain the reason.
To understand __call__() in Python, you can read:
Python __call__(): Call Function with Dynamic Parameters – Python Tutorial
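As a quick refresher: defining __call__() on a class makes its instances callable, so writing obj(x) runs obj.__call__(x). The Greeter class below is a hypothetical example used only to illustrate this; it is not part of TensorFlow.

```python
class Greeter:
    """Hypothetical class whose instances can be called like functions."""

    def __init__(self, greeting):
        self.greeting = greeting

    def __call__(self, name):
        # greet("RNN") is translated by Python into Greeter.__call__(greet, "RNN")
        return f"{self.greeting}, {name}!"


greet = Greeter("Hello")
print(greet("RNN"))  # prints: Hello, RNN!
```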
__call__() and call() in TensorFlow RNNCell
To create a custom RNN cell, we can subclass RNNCell:
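```python
# tensorflow.contrib is only available in TensorFlow 1.x
from tensorflow.contrib.rnn import RNNCell


class DecoderPrenetWrapper(RNNCell):
    # Define __init__(), state_size, output_size and either __call__() or call() here
    ...
```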
Here DecoderPrenetWrapper is a subclass of RNNCell, and it can override either __call__() or call(). This works because of the inheritance chain (each class inherits from the one to its right):
DecoderPrenetWrapper <- tensorflow.contrib.rnn.RNNCell <- base_layer.Layer
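In base_layer.Layer, a call() method is defined, and Layer.__call__() invokes it: __call__() performs shared setup (such as building the layer the first time it is used) and then calls self.call(). A subclass can therefore either override call() and keep that setup, or override __call__() and replace it.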
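This delegation pattern can be imitated in plain Python. The sketch below uses a hypothetical, heavily simplified stand-in for base_layer.Layer (it is not TensorFlow code): its __call__() does some shared bookkeeping, here just a build flag and a call counter, and then delegates to call(). DoublingLayer and RawDoublingLayer are likewise hypothetical names chosen for illustration.

```python
class Layer:
    """Hypothetical, simplified stand-in for base_layer.Layer (not TensorFlow code).

    The "bookkeeping" is only a build flag and a call counter,
    so we can observe whether Layer.__call__() actually ran.
    """

    def __init__(self):
        self.built = False
        self.call_count = 0

    def build(self, inputs):
        # Stand-in for creating weights the first time the layer is used
        self.built = True

    def __call__(self, inputs, *args, **kwargs):
        # Shared bookkeeping that every subclass inherits
        if not self.built:
            self.build(inputs)
        self.call_count += 1
        # Delegate the actual computation to call()
        return self.call(inputs, *args, **kwargs)

    def call(self, inputs, **kwargs):
        # Default behaviour: return the inputs unchanged
        return inputs


class DoublingLayer(Layer):
    # Overrides call(): the parent's __call__() still does its bookkeeping
    def call(self, inputs, **kwargs):
        return inputs * 2


class RawDoublingLayer(Layer):
    # Overrides __call__(): the parent's bookkeeping is bypassed entirely
    def __call__(self, inputs, *args, **kwargs):
        return inputs * 2


a, b = DoublingLayer(), RawDoublingLayer()
print(a(3), a.built, a.call_count)  # prints: 6 True 1
print(b(3), b.built, b.call_count)  # prints: 6 False 0
```

Both objects are used the same way, as a(3) and b(3); the only difference is whether the parent's __call__() logic runs. Overriding call() is the route that keeps the parent class's setup intact.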
The default call() in base_layer.Layer is defined as:
```python
@doc_controls.for_subclass_implementers
def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Arguments:
        inputs: Input tensor, or list/tuple of input tensors.
        **kwargs: Additional keyword arguments.

    Returns:
        A tensor or list/tuple of tensors.
    """
    return inputs
```
This default call() simply returns its inputs unchanged, so to create a custom RNN cell we override call() with our own logic.
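To make the idea concrete without requiring TensorFlow, here is a plain-Python sketch. Layer is again a hypothetical minimal stand-in for base_layer.Layer, and ScalarRNNCell is a hypothetical one-unit Elman cell computing h_t = tanh(w * x_t + u * h_{t-1} + b); the weights w, u, b are arbitrary illustrative values. Only call(inputs, state) is overridden, and, as with TensorFlow RNN cells, it returns a pair (output, new_state).

```python
import math


class Layer:
    """Hypothetical minimal stand-in for base_layer.Layer (not TensorFlow code)."""

    def __call__(self, *args, **kwargs):
        # Shared setup would happen here before delegating to call()
        return self.call(*args, **kwargs)

    def call(self, inputs, **kwargs):
        # Default behaviour: return the inputs unchanged
        return inputs


class ScalarRNNCell(Layer):
    """Hypothetical one-unit cell: h_t = tanh(w * x_t + u * h_{t-1} + b)."""

    def __init__(self, w=0.5, u=0.5, b=0.0):
        self.w, self.u, self.b = w, u, b

    def call(self, inputs, state):
        # Only call() is overridden; Layer.__call__() forwards (inputs, state) here
        new_state = math.tanh(self.w * inputs + self.u * state + self.b)
        # For this simple cell the output is the new hidden state
        return new_state, new_state


cell = ScalarRNNCell()
state = 0.0
for x in [1.0, 0.5, -1.0]:
    output, state = cell(x, state)  # goes through __call__() and then call()
print(output, state)
```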