Tutorial Example

Understand Batch Normalization: A Beginner's Explanation – Machine Learning Tutorial

Batch normalization was proposed in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. In this tutorial, we explain it for machine learning beginners.

What is Batch Normalization?

Batch normalization normalizes a batch of samples so that the values have zero mean and unit variance, and then applies a learnable scale and shift.

For example, suppose there are 64 samples in a training step and each sample is a 1 × 200 vector, which means we have a 64 × 200 matrix.

We can normalize this batch of samples with batch normalization:

\(y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta\)

where \(\mu\) is the mean of the samples, \(\sigma^2\) is the variance of the samples, \(\lambda\) is the scale, \(\beta\) is the shift, and \(\epsilon\) is a small constant that prevents division by zero.
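The formula can be sketched in NumPy for the 64 × 200 example. This is a minimal illustration, not a library implementation: it assumes the mean and variance are computed per feature (column) over the batch axis, as is usual for a 2-D input, and the function name `batch_norm`, the random input, and the choice `eps=1e-5` are made up for the example.

```python
import numpy as np

def batch_norm(x, scale, shift, eps=1e-5):
    """Normalize each column of x over the batch axis, then scale and shift."""
    mu = x.mean(axis=0)   # per-feature mean, shape (200,)
    var = x.var(axis=0)   # per-feature (biased) variance, shape (200,)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return scale * x_hat + shift

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 200))  # 64 samples, 200 features

scale = np.ones(200)   # lambda in the formula (assumed initial value)
shift = np.zeros(200)  # beta in the formula (assumed initial value)
y = batch_norm(x, scale, shift)

print(y.shape)
print(y.mean(axis=0)[:3], y.var(axis=0)[:3])
```

With the scale set to 1 and the shift set to 0, each column of `y` should have a mean close to 0 and a variance close to 1.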

To learn how to compute \(\mu\) and \(\sigma^2\), you can read:

Understand the Mean and Variance Computed in Batch Normalization – Machine Learning Tutorial

Batch Normalization in PyTorch and TensorFlow

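PyTorch and TensorFlow apply the same equation, with \(\epsilon\) added inside the square root, but their batch normalization layers use different default values. We compare them in the table below:

| | PyTorch | TensorFlow |
| --- | --- | --- |
| equation | \(y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta\) | \(y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta\) |
| \(\epsilon\) | 1e-5 | 1e-3 |
| momentum | 0.1 | 0.99 |

Note that the two momentum values follow opposite conventions: PyTorch's momentum weights the new batch statistic, while TensorFlow's weights the existing moving average.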
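The sketch below shows how the momentum enters the running-mean update in each framework. It assumes the update rules documented for `torch.nn.BatchNorm1d` and `tf.keras.layers.BatchNormalization`; the two function names are hypothetical and written only for illustration.

```python
def pytorch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum is the weight of the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

def tensorflow_style_update(moving, batch_stat, momentum=0.99):
    # TensorFlow (Keras) convention: momentum is the weight of the OLD moving statistic.
    return momentum * moving + (1.0 - momentum) * batch_stat

running_mean = 0.0
batch_mean = 1.0

print(pytorch_style_update(running_mean, batch_mean))
print(tensorflow_style_update(running_mean, batch_mean))
```

Under these rules, a PyTorch momentum of 0.1 behaves like a TensorFlow momentum of 0.9, so the defaults 0.1 and 0.99 are further apart than they look: TensorFlow's default moves the running statistics more slowly.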

How to use batch normalization?

To apply batch normalization, we need the values of four variables. They are:

| Variable | Description | How to get it in TensorFlow |
| --- | --- | --- |
| \(\mu\) | The mean of batch samples | tf.nn.moments() |
| \(\sigma^2\) | The variance of samples | tf.nn.moments() |
| \(\lambda\) | The scale | Learned by training |
| \(\beta\) | The shift parameter | Learned by training |

Note that:

If \(\lambda = 1\) and \(\beta = 0\), batch normalization reduces to standardization.

If \(\lambda = \sqrt{\sigma^2+\epsilon}\) (approximately \(\sigma\)) and \(\beta = \mu\), the output equals the input, \(y_i = x_i\), which means batch normalization has no effect.
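Both special cases can be checked numerically. The following is a small sketch under assumptions: the statistics are taken per feature over the batch axis, the input is random data invented for the example, and the scale in the second case is written as \(\sqrt{\sigma^2+\epsilon}\), the exact form of \(\sigma\) once \(\epsilon\) is included.

```python
import numpy as np

eps = 1e-5
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 200))

mu = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)  # standardized input

# Case 1: scale = 1, shift = 0 leaves the standardized values unchanged.
y_standardized = 1.0 * x_hat + 0.0

# Case 2: scale = sqrt(var + eps), shift = mu undoes the normalization.
y_identity = np.sqrt(var + eps) * x_hat + mu

print(np.allclose(y_standardized, x_hat))
print(np.allclose(y_identity, x))
```

Because the scale and shift are learned, the network can move anywhere between these two extremes during training.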

To use batch normalization in a model, see this tutorial:

How to Update the Mean and Variance of Population and Test Sample in Batch Normalization – Machine Learning Tutorial