Computing the mean and variance of a tensor in TensorFlow is straightforward. This tutorial shows how to calculate both with the tf.nn.moments() function.
If you want to learn how to compute variance and standard deviation in numpy, you can read:
Calculate Average, Variance, Standard Deviation of a Matrix in Numpy
What is variance?
You can find the definition of variance in this tutorial:
What is Sample Variance and How to Compute it in Numpy – Numpy Tutorial
In TensorFlow, we can compute the mean and variance together with the tf.nn.moments() function.
Syntax
tf.nn.moments() is defined as:
tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)
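It calculates the mean and variance of x along the dimensions given in axes and returns them as two tensors.
Note the parameter names, which differ from the NumPy ones:
- The parameter is named axes, not axis.
- The parameter is named keep_dims, not keepdims.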
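The variances printed in this tutorial are population variances: the squared deviations are averaged over N values, not N - 1. As a minimal sketch of that definition, the snippet below uses Python's standard statistics module as a stand-in for the TensorFlow op. The variable name row is illustrative, and its values are one row of the example tensor used later in this tutorial.

```python
import statistics

# One row taken from the example tensor used later in this tutorial.
row = [-1.0, 3.0, 2.0]

mean = statistics.mean(row)                  # arithmetic mean
variance = statistics.pvariance(row)         # population variance: divides by N
sample_variance = statistics.variance(row)   # sample variance: divides by N - 1

print(mean, variance, sample_variance)
```

The population variance here (about 2.8889) is the one that matches the first value in the tf.nn.moments() output below. The sample variance is larger because it divides by N - 1.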
How to use tf.nn.moments()?
We will use some examples to show how to use it.
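import numpy as np
import tensorflow as tf

# A float32 tensor of shape (2, 2, 3)
xs = tf.convert_to_tensor(np.array([[[-1, 3, 2], [-3, 1, 3]], [[2, -7, 4], [5, 7, 6]]]), dtype=tf.float32)
# Mean and variance along the last dimension, kept with length 1
fc_mean, fc_var = tf.nn.moments(xs, axes=2, keep_dims=True)

# tf.global_variables_initializer() replaces the deprecated tf.initialize_all_variables()
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(fc_mean))
    print(sess.run(fc_var))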
In this example, xs is a tensor of shape (2, 2, 3). We compute its mean and variance along axes = 2, the last dimension.
Running this code prints the mean followed by the variance:
[[[ 1.3333334 ]
  [ 0.33333334]]

 [[-0.33333334]
  [ 6.        ]]]
[[[ 2.8888893]
  [ 6.222223 ]]

 [[22.888887 ]
  [ 0.6666667]]]
The mean and variance both have shape [2, 2, 1] because we set keep_dims=True, which keeps the reduced dimension with length 1.
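To make the effect of keep_dims on the result shape concrete, here is a minimal plain-Python sketch that reduces the same values over the last dimension in both ways. It uses nested lists instead of tensors, and the helper shape() is a hypothetical utility written for this illustration, not part of TensorFlow.

```python
import statistics

# Same values as the xs tensor above, as nested lists with shape (2, 2, 3).
xs = [[[-1, 3, 2], [-3, 1, 3]],
      [[2, -7, 4], [5, 7, 6]]]

def shape(nested):
    """Return the shape of a regular nested list (illustrative helper)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)

# Reduce over the last dimension and drop it, like keep_dims=False.
mean_dropped = [[statistics.mean(row) for row in block] for block in xs]

# Reduce over the last dimension but keep it with length 1, like keep_dims=True.
mean_kept = [[[statistics.mean(row)] for row in block] for block in xs]

print(shape(mean_dropped))  # (2, 2)
print(shape(mean_kept))     # (2, 2, 1)
```

Keeping the reduced dimension means the result still has the same number of dimensions as xs, which is convenient when the mean is later subtracted from xs by broadcasting.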
What if axes is a list, for example axes = [1, 2]?
# Reduce over dimensions 1 and 2 together
fc_mean, fc_var = tf.nn.moments(xs, axes=[1, 2], keep_dims=True)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(fc_mean))
    print(sess.run(fc_var))
You will get this value:
[[[0.8333333]]

 [[2.8333333]]]
[[[ 4.805556]]

 [[21.805555]]]
What if axes = 1? Here it is passed as the one-element list [1]:
# Reduce over dimension 1 only
fc_mean, fc_var = tf.nn.moments(xs, axes=[1], keep_dims=True)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(fc_mean))
    print(sess.run(fc_var))
You will get this value:
[[[-2.   2.   2.5]]

 [[ 3.5  0.   5. ]]]
[[[ 1.    1.    0.25]]

 [[ 2.25 49.    1.  ]]]
axes determines which dimensions of x are reduced when computing the mean and variance, so the same x gives results with different shapes and values for different axes.
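The three choices of axes used in this tutorial differ only in which values are grouped together before the statistics are taken. The following plain-Python sketch reproduces those groupings with nested lists and the standard statistics module. It is an illustration of the grouping, not of how TensorFlow implements the op, and the variable names are illustrative.

```python
import statistics

# Same values as the xs tensor above, as nested lists with shape (2, 2, 3).
xs = [[[-1, 3, 2], [-3, 1, 3]],
      [[2, -7, 4], [5, 7, 6]]]

# axes = 2: one variance per innermost row (3 values each).
var_axis2 = [[statistics.pvariance(row) for row in block] for block in xs]

# axes = [1, 2]: one variance per outer block (all 6 values pooled).
var_axis12 = [statistics.pvariance([v for row in block for v in row])
              for block in xs]

# axes = 1: one variance per column inside each block (2 values each).
var_axis1 = [[statistics.pvariance(col) for col in zip(*block)]
             for block in xs]

print(var_axis2)
print(var_axis12)
print(var_axis1)
```

Up to float32 rounding, the printed values agree with the variances shown in the three TensorFlow outputs above.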