The TensorFlow tf.nn.batch_normalization() function normalizes a tensor using a given mean and variance. In this tutorial, we will use some examples to show you how to use it.
tf.nn.batch_normalization()
tf.nn.batch_normalization() is defined as:
tf.nn.batch_normalization( x, mean, variance, offset, scale, variance_epsilon, name=None )
It normalizes the input \(x\) using the given mean and variance.
TensorFlow applies the following normalization equation:
\(y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta\)
where:
\(\mu\) is the mean
\(\sigma^2\) is the variance
\(\beta\) is the offset
\(\lambda\) is the scale
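To make the equation concrete, here is a minimal NumPy sketch of the same computation (the sample values and variable names are just for illustration):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
mu = x.mean()        # mean
var = x.var()        # variance
epsilon = 1e-12      # variance_epsilon
scale = 1.0          # lambda in the equation above
offset = 0.0         # beta in the equation above

# y_i = lambda * (x_i - mu) / sqrt(sigma^2 + epsilon) + beta
y = scale * (x - mu) / np.sqrt(var + epsilon) + offset
print(y)  # [-1.34164079 -0.4472136   0.4472136   1.34164079]

The output has zero mean and unit variance, as expected.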
Note that batch normalization is implemented differently in PyTorch and TensorFlow. Here is a tutorial on this topic:
Understand Batch Normalization: A Beginner Explain – Machine Learning Tutorial
How to use tf.nn.batch_normalization()?
In order to use tf.nn.batch_normalization(), we first have to compute the mean and variance of the input \(x\). We can use TensorFlow's tf.nn.moments() to get them.
Here is an example:
import tensorflow as tf

x1 = tf.convert_to_tensor(
    [[[18.369314, 2.6570225, 20.402943],
      [10.403599, 2.7813416, 20.794857]],
     [[19.0327, 2.6398268, 6.3894367],
      [3.921237, 10.761424, 2.7887821]],
     [[11.466338, 20.210938, 8.242946],
      [22.77081, 11.555874, 11.183836]],
     [[8.976935, 10.204252, 11.20231],
      [-7.356888, 6.2725096, 1.1952505]]])
Here we have created an input \(x1\) with shape [4, 2, 3]. Next, we compute its mean and variance:
# tf.nn.moments() returns the mean and the variance (not the standard deviation)
mean_x, var_x = tf.nn.moments(x1, axes=[2], keep_dims=True)
We should notice axes=[2]: the mean and variance are computed along the last dimension of \(x1\), so each row is normalized over its features. Strictly speaking, this is layer normalization rather than batch normalization, which would compute the statistics along the batch dimension, as sketched below.
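For comparison, here is a minimal sketch of what normalizing along the batch dimension would look like; the names mean_b, var_b and v_batch are only illustrative:

# Statistics along the batch dimension (axis 0) give per-feature
# mean and variance, which is what batch normalization uses.
mean_b, var_b = tf.nn.moments(x1, axes=[0], keep_dims=True)
v_batch = tf.nn.batch_normalization(x1, mean_b, var_b, None, None, 1e-12)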
Back to our example, we normalize \(x1\) with the per-row mean and variance:

v1 = tf.nn.batch_normalization(x1, mean_x, var_x, None, None, 1e-12)

with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    print(sess1.run(v1))

Here offset and scale are None, which means \(\beta=0\) and \(\lambda=1\) in the equation above.
Running this code, we get the following result:
[[[ 0.574993   -1.4064413   0.8314482 ]
  [-0.12501884 -1.1574404   1.2824591 ]]

 [[ 1.3801125  -0.95738953 -0.422723  ]
  [-0.5402142   1.4019756  -0.86176133]]

 [[-0.36398554  1.3654773  -1.0014919 ]
  [ 1.4136491  -0.67222667 -0.7414224 ]]

 [[-1.2645674   0.08396816  1.1806011 ]
  [-1.3146634   1.108713    0.20595042]]]
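As a sanity check, here is a small sketch that verifies each row of \(v1\) now has roughly zero mean and unit variance along axis 2 (because offset and scale are None); sess2 is just another session name:

# With offset=None and scale=None, every normalized row should have
# approximately zero mean and unit variance.
mean_v, var_v = tf.nn.moments(v1, axes=[2], keep_dims=True)

with tf.Session() as sess2:
    m, v = sess2.run([mean_v, var_v])
    print(m)  # values close to 0
    print(v)  # values close to 1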