Understand tf.nn.batch_normalization(): Normalize a Layer – TensorFlow Tutorial

By | May 24, 2021

The TensorFlow function tf.nn.batch_normalization() normalizes a tensor using a mean and variance that you supply, then optionally scales and shifts the result. In this tutorial, we will use an example to show you how to use it.

tf.nn.batch_normalization()

tf.nn.batch_normalization() is defined as:

tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)

It normalizes input \(x\) with the given \(mean\) and \(variance\). It does not compute these statistics itself, so you must pass them in.

TensorFlow computes the normalization as:

\(y_i=\lambda(\frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}})+\beta\)

Here

\(\mu\) is mean

\(\sigma^2\) is variance

\(\beta\) is offset

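\(\lambda\) is scale

\(\epsilon\) is variance_epsilon, a small number added to the variance to avoid division by zero

If offset or scale is None, \(\beta\) is treated as 0 and \(\lambda\) as 1.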
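To make the formula concrete, here is a minimal NumPy sketch of the same computation. NumPy and the helper name batch_normalization_ref are our own choices for illustration, not part of TensorFlow; the sketch treats a missing scale as \(\lambda=1\) and a missing offset as \(\beta=0\).

```python
import numpy as np

def batch_normalization_ref(x, mean, variance, offset=None, scale=None,
                            variance_epsilon=1e-12):
    """Hypothetical reference helper: y = scale * (x - mean) / sqrt(variance + eps) + offset.

    mean, variance, offset and scale must broadcast against x.
    """
    x = np.asarray(x, dtype=np.float64)
    y = (x - mean) / np.sqrt(np.asarray(variance) + variance_epsilon)
    if scale is not None:   # lambda; a missing scale behaves like 1
        y = y * scale
    if offset is not None:  # beta; a missing offset behaves like 0
        y = y + offset
    return y

# First vector of the example input used later in this tutorial.
row = [18.369314, 2.6570225, 20.402943]
mu = np.mean(row)
var = np.var(row)  # population variance (divides by N)
print(batch_normalization_ref(row, mu, var))
```

For this row the result is approximately [0.575, -1.406, 0.831], which matches the first row TensorFlow prints in the example below.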

Batch normalization itself is implemented differently in PyTorch and TensorFlow. For details, see this tutorial:

Understand Batch Normalization: A Beginner Explain – Machine Learning Tutorial

How to use tf.nn.batch_normalization()?

To use tf.nn.batch_normalization(), we first need the mean and variance of input \(x\). We can compute them with tf.nn.moments().
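tf.nn.moments() returns the mean and the variance over the axes you pass. The sketch below mimics that in NumPy (moments_ref is a hypothetical helper, not a TensorFlow function). Note that the variance is the population variance, dividing by N rather than N - 1.

```python
import numpy as np

def moments_ref(x, axes, keepdims=False):
    """Hypothetical NumPy stand-in for the statistics tf.nn.moments() returns."""
    x = np.asarray(x, dtype=np.float64)
    axes = tuple(axes)
    mean = np.mean(x, axis=axes, keepdims=True)
    # Population variance: mean of squared deviations from the mean.
    variance = np.mean(np.square(x - mean), axis=axes, keepdims=True)
    if not keepdims:
        mean = np.squeeze(mean, axis=axes)
        variance = np.squeeze(variance, axis=axes)
    return mean, variance

# Statistics over axis 1: one mean and one variance per row.
m, v = moments_ref([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], axes=[1], keepdims=True)
print(m.shape, m.ravel(), v.ravel())
```

With keepdims=True the reduced axis is kept with size 1, so the statistics broadcast back against the input, which is why the example below passes keep_dims=True.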

Here is an example:

import tensorflow as tf

# Input of shape (4, 2, 3): 4 samples, each with 2 vectors of 3 features
x1 = tf.convert_to_tensor(
    [[[18.369314, 2.6570225, 20.402943],
      [10.403599, 2.7813416, 20.794857]],
     [[19.0327, 2.6398268, 6.3894367],
      [3.921237, 10.761424, 2.7887821]],
     [[11.466338, 20.210938, 8.242946],
      [22.77081, 11.555874, 11.183836]],
     [[8.976935, 10.204252, 11.20231],
      [-7.356888, 6.2725096, 1.1952505]]])

Here we have created an input \(x1\) with shape (4, 2, 3). Next, we compute its mean and variance.

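# tf.nn.moments() returns the mean and the variance (not the standard deviation)
mean_x, var_x = tf.nn.moments(x1, axes=[2], keep_dims=True)

Notice \(axes = [2]\): the mean and variance are computed over the last dimension, so each 3-element vector in \(x1\) is normalized on its own. Strictly speaking, this is layer normalization, not batch normalization, even though we call tf.nn.batch_normalization().

# offset and scale are None, so no shift or scaling follows the normalization
v1 = tf.nn.batch_normalization(x1, mean_x, var_x, None, None, 1e-12)

# TensorFlow 1.x graph mode: evaluate the tensor in a session
with tf.Session() as sess1:
    print(sess1.run(v1))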

Running this code prints:

[[[ 0.574993   -1.4064413   0.8314482 ]
  [-0.12501884 -1.1574404   1.2824591 ]]

 [[ 1.3801125  -0.95738953 -0.422723  ]
  [-0.5402142   1.4019756  -0.86176133]]

 [[-0.36398554  1.3654773  -1.0014919 ]
  [ 1.4136491  -0.67222667 -0.7414224 ]]

 [[-1.2645674   0.08396816  1.1806011 ]
  [-1.3146634   1.108713    0.20595042]]]
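To sanity-check this output without TensorFlow, the NumPy sketch below (NumPy and the helper normalize are assumptions for illustration) reproduces it with per-vector statistics over axis 2, and contrasts it with statistics over axis 0, the batch axis. The tf.nn.moments() documentation describes axes=[0] (batch only) for simple batch normalization.

```python
import numpy as np

# Same values as x1 above, shape (4, 2, 3).
x1 = np.array(
    [[[18.369314, 2.6570225, 20.402943],
      [10.403599, 2.7813416, 20.794857]],
     [[19.0327, 2.6398268, 6.3894367],
      [3.921237, 10.761424, 2.7887821]],
     [[11.466338, 20.210938, 8.242946],
      [22.77081, 11.555874, 11.183836]],
     [[8.976935, 10.204252, 11.20231],
      [-7.356888, 6.2725096, 1.1952505]]])

def normalize(x, axes, eps=1e-12):
    # Mean and population variance with kept dims, as tf.nn.moments(..., keep_dims=True)
    # computes them, then the normalization formula with no scale or offset.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# axes=(2,): statistics per 3-element vector (the layer normalization above).
layer_out = normalize(x1, axes=(2,))
print(np.round(layer_out, 4))

# axes=(0,): statistics per (position, feature) across the 4 samples (batch only).
batch_out = normalize(x1, axes=(0,))
print(batch_out.mean(axis=0))  # close to 0 everywhere
```

Each vector of layer_out has mean about 0 and variance about 1 and agrees with the TensorFlow output to about four decimal places, while batch_out has mean about 0 and variance about 1 across the batch dimension instead.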
