The Conformer network was proposed in the paper "Conformer: Convolution-augmented Transformer for Speech Recognition". In this tutorial, we will introduce its feed forward module and show how to implement it using TensorFlow.
Conformer Feed Forward Module
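The structure of this module looks like this (pre-norm layer normalization → linear layer → Swish activation → dropout → linear layer → dropout, followed by a residual connection):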

The first linear layer uses an expansion factor of 4, and the second linear layer projects the result back to the model dimension. The module uses the Swish activation and a pre-norm residual unit.
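Note that layer normalization is applied in the pre-norm position: it normalizes the module's input before the first linear layer, rather than being applied after the residual addition.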
Post-Norm and Pre-Norm Residual Units Explained – Deep Learning Tutorial
How to implement the feed forward module?
In this tutorial, we will use TensorFlow 1.x to implement it.
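Note the 1/2 * FFN term in the Conformer block: the feed forward module uses a half-step residual connection, $\tilde{x} = x + \frac{1}{2}\,\mathrm{FFN}(x)$, so the FFN output (not the input) is scaled by 1/2 before being added to the input.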

Here is the example code:
import tensorflow as tf
class FeedForward_Module():
    """Conformer feed forward module for TensorFlow 1.x graph mode.

    Computes the half-step residual used in the Conformer block::

        y = x + output_reduction_factor * FFN(x)

    FFN is the pre-norm stack:
    LayerNorm -> Dense(encoder_dim * expansion_factor) -> Swish -> Dropout
    -> Dense(encoder_dim) -> Dropout.
    """

    def __init__(self, expansion_factor=4.0, output_reduction_factor=0.5):
        """Store the module hyper-parameters.

        Args:
            expansion_factor: width multiplier of the first (hidden) dense
                layer relative to ``encoder_dim``.
            output_reduction_factor: scale applied to the FFN branch before it
                is added to the residual input. The default of 0.5 gives the
                1/2 * FFN half-step residual.
        """
        self.expansion_factor = expansion_factor  # hidden width = encoder_dim * expansion_factor
        self.output_reduction_factor = output_reduction_factor

    def _expanded_units(self, encoder_dim):
        """Return the hidden-layer width as an int; Dense expects integer units."""
        return int(encoder_dim * self.expansion_factor)

    def call(self, inputs, encoder_dim=512, dropout_rate=0.4, training=False):
        """Apply the feed forward module to ``inputs``.

        Each call creates new dense-layer variables in the current graph.

        Args:
            inputs: float tensor, e.g. of shape (batch_size, dim). Its last
                dimension must equal ``encoder_dim`` because the branch output
                is added back to it.
            encoder_dim: model dimension; the output width of the second
                dense layer.
            dropout_rate: fraction of units dropped by each dropout layer.
            training: whether dropout is active. ``tf.layers.dropout`` returns
                its input unchanged when this is False, so pass True while
                training.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # 1. Pre-norm: normalize the input over its last axis.
        outputs = tf.contrib.layers.layer_norm(
            inputs=inputs, begin_norm_axis=-1, begin_params_axis=-1)
        # 2. Expand to encoder_dim * expansion_factor hidden units.
        outputs = tf.layers.dense(inputs=outputs,
                                  units=self._expanded_units(encoder_dim))
        # 3. Swish activation.
        outputs = tf.nn.swish(outputs)
        # 4. Dropout; `training` must be forwarded or the layer is a no-op.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 5. Project back to the model dimension.
        outputs = tf.layers.dense(inputs=outputs, units=encoder_dim)
        # 6. Dropout.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 7. Half-step residual x + 1/2 * FFN(x): scale the branch, not the input.
        return inputs + self.output_reduction_factor * outputs
- import tensorflow as tf
class FeedForward_Module():
    """Conformer feed forward module for TensorFlow 1.x graph mode.

    Computes the half-step residual used in the Conformer block::

        y = x + output_reduction_factor * FFN(x)

    FFN is the pre-norm stack:
    LayerNorm -> Dense(encoder_dim * expansion_factor) -> Swish -> Dropout
    -> Dense(encoder_dim) -> Dropout.
    """

    def __init__(self, expansion_factor=4.0, output_reduction_factor=0.5):
        """Store the module hyper-parameters.

        Args:
            expansion_factor: width multiplier of the first (hidden) dense
                layer relative to ``encoder_dim``.
            output_reduction_factor: scale applied to the FFN branch before it
                is added to the residual input. The default of 0.5 gives the
                1/2 * FFN half-step residual.
        """
        self.expansion_factor = expansion_factor  # hidden width = encoder_dim * expansion_factor
        self.output_reduction_factor = output_reduction_factor

    def _expanded_units(self, encoder_dim):
        """Return the hidden-layer width as an int; Dense expects integer units."""
        return int(encoder_dim * self.expansion_factor)

    def call(self, inputs, encoder_dim=512, dropout_rate=0.4, training=False):
        """Apply the feed forward module to ``inputs``.

        Each call creates new dense-layer variables in the current graph.

        Args:
            inputs: float tensor, e.g. of shape (batch_size, dim). Its last
                dimension must equal ``encoder_dim`` because the branch output
                is added back to it.
            encoder_dim: model dimension; the output width of the second
                dense layer.
            dropout_rate: fraction of units dropped by each dropout layer.
            training: whether dropout is active. ``tf.layers.dropout`` returns
                its input unchanged when this is False, so pass True while
                training.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # 1. Pre-norm: normalize the input over its last axis.
        outputs = tf.contrib.layers.layer_norm(
            inputs=inputs, begin_norm_axis=-1, begin_params_axis=-1)
        # 2. Expand to encoder_dim * expansion_factor hidden units.
        outputs = tf.layers.dense(inputs=outputs,
                                  units=self._expanded_units(encoder_dim))
        # 3. Swish activation.
        outputs = tf.nn.swish(outputs)
        # 4. Dropout; `training` must be forwarded or the layer is a no-op.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 5. Project back to the model dimension.
        outputs = tf.layers.dense(inputs=outputs, units=encoder_dim)
        # 6. Dropout.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 7. Half-step residual x + 1/2 * FFN(x): scale the branch, not the input.
        return inputs + self.output_reduction_factor * outputs
import tensorflow as tf
class FeedForward_Module():
    """Conformer feed forward module for TensorFlow 1.x graph mode.

    Computes the half-step residual used in the Conformer block::

        y = x + output_reduction_factor * FFN(x)

    FFN is the pre-norm stack:
    LayerNorm -> Dense(encoder_dim * expansion_factor) -> Swish -> Dropout
    -> Dense(encoder_dim) -> Dropout.
    """

    def __init__(self, expansion_factor=4.0, output_reduction_factor=0.5):
        """Store the module hyper-parameters.

        Args:
            expansion_factor: width multiplier of the first (hidden) dense
                layer relative to ``encoder_dim``.
            output_reduction_factor: scale applied to the FFN branch before it
                is added to the residual input. The default of 0.5 gives the
                1/2 * FFN half-step residual.
        """
        self.expansion_factor = expansion_factor  # hidden width = encoder_dim * expansion_factor
        self.output_reduction_factor = output_reduction_factor

    def _expanded_units(self, encoder_dim):
        """Return the hidden-layer width as an int; Dense expects integer units."""
        return int(encoder_dim * self.expansion_factor)

    def call(self, inputs, encoder_dim=512, dropout_rate=0.4, training=False):
        """Apply the feed forward module to ``inputs``.

        Each call creates new dense-layer variables in the current graph.

        Args:
            inputs: float tensor, e.g. of shape (batch_size, dim). Its last
                dimension must equal ``encoder_dim`` because the branch output
                is added back to it.
            encoder_dim: model dimension; the output width of the second
                dense layer.
            dropout_rate: fraction of units dropped by each dropout layer.
            training: whether dropout is active. ``tf.layers.dropout`` returns
                its input unchanged when this is False, so pass True while
                training.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # 1. Pre-norm: normalize the input over its last axis.
        outputs = tf.contrib.layers.layer_norm(
            inputs=inputs, begin_norm_axis=-1, begin_params_axis=-1)
        # 2. Expand to encoder_dim * expansion_factor hidden units.
        outputs = tf.layers.dense(inputs=outputs,
                                  units=self._expanded_units(encoder_dim))
        # 3. Swish activation.
        outputs = tf.nn.swish(outputs)
        # 4. Dropout; `training` must be forwarded or the layer is a no-op.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 5. Project back to the model dimension.
        outputs = tf.layers.dense(inputs=outputs, units=encoder_dim)
        # 6. Dropout.
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
        # 7. Half-step residual x + 1/2 * FFN(x): scale the branch, not the input.
        return inputs + self.output_reduction_factor * outputs
You may wonder why we use a dropout layer rather than batch normalization after the dense layer; you can read this tutorial:
Dropout vs Batch Normalization – Which is Better for Multilayered Neural Network
Then we can evaluate this module on a random input:
if __name__ == "__main__":
    # Demo: run the module once on a random (batch_size, dim) tensor and
    # print the shape and values of the result.
    import numpy as np

    batch_size, dim = 3, 30
    # encoder_dim must equal dim so the residual addition is shape-compatible.
    encoder_dim = 30
    inputs = tf.random.uniform((batch_size, dim), minval=-10, maxval=10)
    outputs = FeedForward_Module().call(inputs, encoder_dim, dropout_rate=0.4)

    # Build the initializers only after the graph (and its variables) exists.
    init_ops = [tf.global_variables_initializer(),
                tf.local_variables_initializer()]

    np.set_printoptions(precision=4, suppress=True)
    with tf.Session() as sess:
        sess.run(init_ops)
        result = sess.run(outputs)
    print(result.shape)
    print(result)
if __name__ == "__main__":
    # Demo: run the module once on a random (batch_size, dim) tensor and
    # print the shape and values of the result.
    import numpy as np

    batch_size, dim = 3, 30
    # encoder_dim must equal dim so the residual addition is shape-compatible.
    encoder_dim = 30
    inputs = tf.random.uniform((batch_size, dim), minval=-10, maxval=10)
    outputs = FeedForward_Module().call(inputs, encoder_dim, dropout_rate=0.4)

    # Build the initializers only after the graph (and its variables) exists.
    init_ops = [tf.global_variables_initializer(),
                tf.local_variables_initializer()]

    np.set_printoptions(precision=4, suppress=True)
    with tf.Session() as sess:
        sess.run(init_ops)
        result = sess.run(outputs)
    print(result.shape)
    print(result)
if __name__ == "__main__":
    # Demo: run the module once on a random (batch_size, dim) tensor and
    # print the shape and values of the result.
    import numpy as np

    batch_size, dim = 3, 30
    # encoder_dim must equal dim so the residual addition is shape-compatible.
    encoder_dim = 30
    inputs = tf.random.uniform((batch_size, dim), minval=-10, maxval=10)
    outputs = FeedForward_Module().call(inputs, encoder_dim, dropout_rate=0.4)

    # Build the initializers only after the graph (and its variables) exists.
    init_ops = [tf.global_variables_initializer(),
                tf.local_variables_initializer()]

    np.set_printoptions(precision=4, suppress=True)
    with tf.Session() as sess:
        sess.run(init_ops)
        result = sess.run(outputs)
    print(result.shape)
    print(result)
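Running this code prints something like the following (the inputs and weights are random and no seed is set, so the exact values differ on every run):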
(3, 30)
[[ 3.245 5.2575 2.0267 -2.7365 -3.1796 2.8685 -1.9953 2.0717 -1.6839
0.4027 2.4781 -1.7986 -3.9386 5.2049 -2.4076 3.6563 0.3101 4.4532
-2.4095 3.876 5.1791 1.7626 -1.6007 0.8386 -2.2012 -0.0357 4.9984
-3.5259 -3.5232 1.5579]
[ 0.6908 -1.8504 3.9558 3.0444 -2.5875 -2.0062 -3.7148 4.6892 -5.0732
-0.8936 2.5074 -4.5067 2.5742 -4.1987 -1.786 2.351 5.8096 -3.1629
-3.3842 1.9663 -1.1548 0.8982 -3.4875 -4.9175 4.2789 4.0092 -1.5448
-0.9469 1.8976 0.1782]
[ 1.3729 4.4757 -2.5322 4.694 -4.2836 4.0316 -1.2085 -0.6076 0.0301
2.1905 -3.5824 -3.7651 -2.0002 -3.8068 1.5397 -1.0683 -0.2866 3.4809
1.4468 -2.5218 3.7157 0.1056 -1.3569 1.7638 -2.8352 -4.0523 3.7236
-4.3653 4.2616 -0.0585]]
- (3, 30)
- [[ 3.245 5.2575 2.0267 -2.7365 -3.1796 2.8685 -1.9953 2.0717 -1.6839
- 0.4027 2.4781 -1.7986 -3.9386 5.2049 -2.4076 3.6563 0.3101 4.4532
- -2.4095 3.876 5.1791 1.7626 -1.6007 0.8386 -2.2012 -0.0357 4.9984
- -3.5259 -3.5232 1.5579]
- [ 0.6908 -1.8504 3.9558 3.0444 -2.5875 -2.0062 -3.7148 4.6892 -5.0732
- -0.8936 2.5074 -4.5067 2.5742 -4.1987 -1.786 2.351 5.8096 -3.1629
- -3.3842 1.9663 -1.1548 0.8982 -3.4875 -4.9175 4.2789 4.0092 -1.5448
- -0.9469 1.8976 0.1782]
- [ 1.3729 4.4757 -2.5322 4.694 -4.2836 4.0316 -1.2085 -0.6076 0.0301
- 2.1905 -3.5824 -3.7651 -2.0002 -3.8068 1.5397 -1.0683 -0.2866 3.4809
- 1.4468 -2.5218 3.7157 0.1056 -1.3569 1.7638 -2.8352 -4.0523 3.7236
- -4.3653 4.2616 -0.0585]]
(3, 30)
[[ 3.245 5.2575 2.0267 -2.7365 -3.1796 2.8685 -1.9953 2.0717 -1.6839
0.4027 2.4781 -1.7986 -3.9386 5.2049 -2.4076 3.6563 0.3101 4.4532
-2.4095 3.876 5.1791 1.7626 -1.6007 0.8386 -2.2012 -0.0357 4.9984
-3.5259 -3.5232 1.5579]
[ 0.6908 -1.8504 3.9558 3.0444 -2.5875 -2.0062 -3.7148 4.6892 -5.0732
-0.8936 2.5074 -4.5067 2.5742 -4.1987 -1.786 2.351 5.8096 -3.1629
-3.3842 1.9663 -1.1548 0.8982 -3.4875 -4.9175 4.2789 4.0092 -1.5448
-0.9469 1.8976 0.1782]
[ 1.3729 4.4757 -2.5322 4.694 -4.2836 4.0316 -1.2085 -0.6076 0.0301
2.1905 -3.5824 -3.7651 -2.0002 -3.8068 1.5397 -1.0683 -0.2866 3.4809
1.4468 -2.5218 3.7157 0.1056 -1.3569 1.7638 -2.8352 -4.0523 3.7236
-4.3653 4.2616 -0.0585]]