Understand the Mean and Variance Computed in Batch Normalization – Machine Learning Tutorial

By | December 3, 2020

In this tutorial, we will discuss how the mean and variance are computed in batch normalization, which is very useful for understanding how batch normalization works.

For example:

Suppose we have a batch of shape (64, 200), that is, 64 samples with 200 features each. There are three ways to compute the mean (\(\mu\)) and variance (\(\sigma^2\)).

Way 1:

Compute the mean (\(\mu\)) and variance (\(\sigma^2\)) over all 64 × 200 = 12800 values. You will get two scalars.

[Figure: Calculate the mean and variance in batch normalization, way 1]
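A minimal NumPy sketch of way 1, assuming a randomly generated (64, 200) array stands in for the batch (the data and the names `batch`, `mean_all`, `var_all` are illustrative only). `var()` computes the population variance, dividing by the number of values.

```python
import numpy as np

# Hypothetical batch: 64 samples, 200 features (random values, for illustration only)
rng = np.random.default_rng(0)
batch = rng.normal(size=(64, 200))

# Way 1: one mean and one variance over all 64 * 200 = 12800 values
mean_all = batch.mean()
var_all = batch.var()  # population variance: divides by the number of values

print(batch.size)                           # 12800
print(np.ndim(mean_all), np.ndim(var_all))  # 0 0  -> two scalars
```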

Way 2:

Compute the mean (\(\mu\)) and variance (\(\sigma^2\)) along axis = 1, that is, over the 200 features of each sample. You will get two vectors of length 64.

[Figure: Calculate the mean and variance in batch normalization, way 2]
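Way 2 in NumPy, again assuming a hypothetical random (64, 200) batch: reducing along `axis=1` leaves one statistic per sample. Passing `keepdims=True` keeps the result as a (64, 1) column, which broadcasts back against the batch if you want to normalize with it.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(size=(64, 200))  # hypothetical (batch_size, features) data

# Way 2: statistics along axis=1, i.e. over the 200 features of each sample
mean_per_sample = batch.mean(axis=1)
var_per_sample = batch.var(axis=1)
print(mean_per_sample.shape, var_per_sample.shape)  # (64,) (64,)

# keepdims=True keeps a column shape that broadcasts against the (64, 200) batch
print(batch.mean(axis=1, keepdims=True).shape)      # (64, 1)
```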

Way 3:

Compute the mean (\(\mu\)) and variance (\(\sigma^2\)) along axis = 0, that is, over the 64 samples of the batch, separately for each feature. You will get two vectors of length 200.

[Figure: Calculate the mean and variance in batch normalization, way 3]
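A NumPy sketch of way 3 on hypothetical random data (mean 5, standard deviation 3, chosen only so the effect is visible). Reducing along `axis=0` gives one mean and one variance per feature; subtracting and dividing by them broadcasts across the rows, so every column ends up with mean 0 and variance close to 1. The `eps` value of 0.001 is an assumption that matches the epsilon used in the code later in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=3.0, size=(64, 200))  # hypothetical data

# Way 3: statistics along axis=0, i.e. over the 64 samples, separately for each feature
mean_per_feature = batch.mean(axis=0)
var_per_feature = batch.var(axis=0)
print(mean_per_feature.shape, var_per_feature.shape)  # (200,) (200,)

# Normalize each column with its own statistics; (200,) broadcasts over the 64 rows
eps = 0.001
normalized = (batch - mean_per_feature) / np.sqrt(var_per_feature + eps)
print(np.allclose(normalized.mean(axis=0), 0.0))  # True
```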

If we normalize the data with axis = 1 (the feature axis), which way does batch normalization use?

The answer is way 3.
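Written as formulas, way 3 works on a batch \(x\) of shape \((m, n)\), with \(m\) samples and \(n\) features (here \(m = 64\), \(n = 200\)). For each feature \(j\), the statistics are taken over the samples \(i\):

\[
\begin{aligned}
\mu_j &= \frac{1}{m}\sum_{i=1}^{m} x_{ij}, \\
\sigma_j^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_{ij} - \mu_j\right)^2, \\
\hat{x}_{ij} &= \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}}.
\end{aligned}
\]

Here \(\epsilon\) is a small constant that avoids division by zero. Batch normalization then applies a learnable scale and offset, \(y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j\).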

We will use an example to verify this answer.

Look at the code example below:

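```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.Session)

# A batch of 4 samples with 3 features each
cap_map = tf.convert_to_tensor(
    np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=float),
    dtype=tf.float32)

def norm(xs):
    # Way 3: mean and variance along axis 0 (over the batch), one value per feature
    fc_mean, fc_var = tf.nn.moments(xs, axes=[0], keep_dims=True)
    epsilon = 0.001
    # (xs - mean) / sqrt(var + epsilon) * scale + offset, with offset = 0.0 and scale = 1.0
    xs = tf.nn.batch_normalization(xs, fc_mean, fc_var, 0.0, 1.0, epsilon)
    return xs

# Built-in batch normalization with feature axis 1. training=True uses the batch statistics,
# scale=False drops gamma, beta starts at 0, and the default epsilon is 0.001.
n1 = tf.layers.batch_normalization(cap_map, axis=1, training=True, scale=False)
n2 = norm(cap_map)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(n1))
    print(sess.run(n2))
```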

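cap_map is a tensor of shape (4, 3), which means the batch size is 4 and each sample has 3 features.

We normalize the data with axis = 1, the feature axis.

tf.layers.batch_normalization() applies batch normalization to the data. With training=True, it normalizes with the mean and variance of the current batch. Because scale=False and the offset beta starts at 0, n1 is simply the normalized data.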

n1 is:

[[-0.57731885  0.39222473 -1.1829456 ]
 [-1.2371118   0.         -0.5069766 ]
 [ 0.4123706  -1.568899    0.16899228]
 [ 1.40206     1.1766742   1.5209303 ]]

The tf.nn.batch_normalization() function normalizes the data with whatever mean and variance we pass in. In norm(), we compute them with tf.nn.moments() along axis = 0, which is way 3.
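To make explicit what norm() computes, here is a NumPy sketch of the same two steps. The helpers `moments` and `batch_normalization` are hypothetical NumPy stand-ins, not TensorFlow functions: the first returns the mean and population variance as tf.nn.moments does, and the second applies the documented formula scale · (x − mean) / sqrt(var + epsilon) + offset.

```python
import numpy as np

def moments(x, axes):
    # Hypothetical stand-in for tf.nn.moments: mean and population variance,
    # with keepdims=True so the results broadcast back against x
    mean = x.mean(axis=axes, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)
    return mean, var

def batch_normalization(x, mean, var, offset, scale, variance_epsilon):
    # Hypothetical stand-in for tf.nn.batch_normalization
    return scale * (x - mean) / np.sqrt(var + variance_epsilon) + offset

cap_map = np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=np.float32)
mean, var = moments(cap_map, axes=0)  # way 3: one mean/variance per column
n2 = batch_normalization(cap_map, mean, var, offset=0.0, scale=1.0, variance_epsilon=0.001)
print(n2)
```

For the first column, the mean is 0.75 and the variance is 9.1875, so the first entry is (−1 − 0.75) / sqrt(9.1885) ≈ −0.5773.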

Running the TensorFlow code above, we find that n2 is:

[[-0.57731885  0.39222473 -1.1829456 ]
 [-1.2371118   0.         -0.5069766 ]
 [ 0.4123706  -1.568899    0.16899228]
 [ 1.40206     1.1766742   1.5209303 ]]

n1 equals n2.

This shows that batch normalization uses way 3: for each feature, the mean and variance are computed across the samples of the batch (axis = 0).
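As an extra check, this NumPy sketch normalizes cap_map with each of the three ways and compares the result with the n1 values printed above. The helper `normalize` is hypothetical; epsilon is 0.001 as in the TensorFlow code. Only way 3 reproduces n1.

```python
import numpy as np

cap_map = np.array([[-1, 3, 2], [-3, 1, 3], [2, -7, 4], [5, 7, 6]], dtype=np.float64)
# n1 as printed by tf.layers.batch_normalization(..., axis=1, training=True, scale=False)
n1 = np.array([[-0.57731885, 0.39222473, -1.1829456],
               [-1.2371118, 0.0, -0.5069766],
               [0.4123706, -1.568899, 0.16899228],
               [1.40206, 1.1766742, 1.5209303]])
eps = 0.001

def normalize(x, axis):
    # axis=None -> way 1 (whole batch), axis=1 -> way 2 (per sample), axis=0 -> way 3 (per feature)
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

for name, axis in [("way 1", None), ("way 2", 1), ("way 3", 0)]:
    print(name, np.allclose(normalize(cap_map, axis), n1, atol=1e-5))
# way 1 False
# way 2 False
# way 3 True
```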
