Understand torch.nn.utils.weight_norm() with Examples – PyTorch Tutorial

April 14, 2022

Weight normalization was proposed in the paper "Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks". It can speed up the convergence of stochastic gradient descent, and it can be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited.

Weight normalization

Each weight vector \(w\) is rewritten as:

\(w = \frac{g}{\parallel v \parallel}v\)

Here \(v\) is a k-dimensional vector, \(g\) is a scalar, and \(\parallel v \parallel\) denotes the Euclidean norm of \(v\). Because \(\frac{v}{\parallel v \parallel}\) is a unit vector, the norm of the weight is fixed to \(\parallel w \parallel = g\), independent of the scale of \(v\).
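A quick numerical check of this property, as a minimal NumPy sketch. The vector v and scale g below are arbitrary illustrative values, not taken from the paper:

```python
import numpy as np

# Illustrative values only: a direction vector v (k = 3) and a scale g
v = np.array([3.0, 4.0, 12.0])  # ||v|| = 13
g = 2.5

# Weight normalization: w = g / ||v|| * v
w = g / np.linalg.norm(v) * v
print(np.linalg.norm(w))  # approximately 2.5

# Scaling v by 10 leaves w unchanged: only the direction of v matters
v_scaled = 10 * v
w_scaled = g / np.linalg.norm(v_scaled) * v_scaled
print(np.allclose(w, w_scaled))  # True
```

The norm of w is controlled by g alone, while v only contributes a direction.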

Instead of working with \(g\) directly, we can also use an exponential parameterization for the scale, i.e. \(g = e^s\), where \(s\) is a log-scale parameter learned by stochastic gradient descent. In the paper's experiments, the eventual test-set performance with this parameterization was not significantly better or worse than when learning \(g\) directly in its original parameterization, and optimization was slightly slower.
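The log-scale variant can be sketched in NumPy as follows. The value of s and the vector v are hypothetical illustrative choices. The sketch shows two consequences of \(g = e^s\): the scale is positive for any real s, and an additive change to s is a multiplicative change to the norm of w.

```python
import numpy as np

# Hypothetical log-scale parameter s; the scale is g = exp(s)
s = -1.0
g = np.exp(s)                     # positive for any real s
v = np.array([1.0, -2.0, 2.0])    # ||v|| = 3, illustrative values

w = g / np.linalg.norm(v) * v
print(np.isclose(np.linalg.norm(w), np.exp(s)))  # True

# Adding log(2) to s doubles the norm of w
s_new = s + np.log(2.0)
w_new = np.exp(s_new) / np.linalg.norm(v) * v
print(np.isclose(np.linalg.norm(w_new), 2 * np.linalg.norm(w)))  # True
```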

How to implement weight normalization?

In PyTorch, we can use torch.nn.utils.weight_norm(). Its signature is:

torch.nn.utils.weight_norm(module, name='weight', dim=0)
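The name argument selects which parameter of the module to reparameterize. The dim argument selects the dimension that keeps its own entry in g; the norm is computed over all remaining dimensions. The sketch below uses a torch.nn.Conv1d layer chosen only for illustration and assumes a reasonably recent PyTorch. Newer releases may print a deprecation warning for this function that points to torch.nn.utils.parametrizations.weight_norm. This tutorial uses the older API.

```python
import torch
from torch.nn.utils import weight_norm

# Illustrative layer: 3 input channels, 8 output channels, kernel size 5
conv = torch.nn.Conv1d(3, 8, kernel_size=5)
print(conv.weight.shape)  # torch.Size([8, 3, 5])

conv = weight_norm(conv, name="weight", dim=0)

# 'weight' is no longer a parameter; it is replaced by weight_g and weight_v
for name, param in conv.named_parameters():
    print(name, tuple(param.shape))
# bias (8,)
# weight_g (8, 1, 1)   one scale per output channel (dim 0)
# weight_v (8, 3, 5)   same shape as the original weight
```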

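Note that the first argument, module, is a PyTorch module instance (for example a torch.nn.Linear layer), not a tensor. weight_norm() removes the parameter given by name from that module and replaces it with two new parameters: for the default name='weight', these are weight_g (the scale \(g\)) and weight_v (the direction \(v\)). So how exactly does weight normalization normalize the weight of a module?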

Here are some examples:

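import torch
from torch.nn.utils import weight_norm

# A plain linear layer: its weight has shape (out_features, in_features) = (4, 5)
linear = torch.nn.Linear(5, 4, bias=False)
for name, param in linear.named_parameters():
    print(name, param)

# Apply weight normalization to the parameter named 'weight', with dim=0
linear_norm = weight_norm(linear, dim=0)
print(linear_norm)
for name, param in linear_norm.named_parameters():
    print(name, param)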

Running this code prints the following.

The original weight w in torch.nn.Linear is:

weight Parameter containing:
tensor([[ 0.2452, -0.0739, -0.4033,  0.0609,  0.0149],
        [ 0.2768,  0.1902, -0.4268,  0.4366,  0.2967],
        [ 0.1422, -0.4120,  0.2413, -0.0029,  0.1489],
        [ 0.2342,  0.0239,  0.0312, -0.1828, -0.3770]], requires_grad=True)

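linear_norm, the module returned by weight_norm(), prints as:

Linear(in_features=5, out_features=4, bias=False)

This is the same as linear. In fact, weight_norm() returns the very module it was given, so linear_norm and linear are the same object.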
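weight_norm() modifies the module it receives and returns that same module rather than a copy. A short sketch to check this, using the same layer shape as above:

```python
import torch
from torch.nn.utils import weight_norm

linear = torch.nn.Linear(5, 4, bias=False)
linear_norm = weight_norm(linear, dim=0)

# The returned module is the same object, modified in place
print(linear_norm is linear)  # True

# The original module no longer owns a 'weight' parameter
print([name for name, _ in linear.named_parameters()])  # ['weight_g', 'weight_v']
```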

In this example, \(g\) is stored in the new parameter weight_g:

weight_g Parameter containing:
tensor([[0.4819],
        [0.7574],
        [0.5199],
        [0.4815]], requires_grad=True)

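Here dim = 0, so weight_g has one entry per row of the 4 × 5 weight matrix (one per output feature), and each row of \(v\) is normalized separately:

\(w_{i,:} = \frac{g_i}{\parallel v_{i,:} \parallel} v_{i,:}, \quad i = 1, \dots, 4\)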
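For dim = 0, each row of v is divided by its own Euclidean norm and scaled by the matching entry of g. The sketch below rebuilds the weight by hand from weight_g and weight_v and compares the layer's output with the manual computation. The seed and input size are arbitrary choices for illustration.

```python
import torch
from torch.nn.utils import weight_norm

torch.manual_seed(0)  # arbitrary seed, only to make the sketch reproducible
linear = weight_norm(torch.nn.Linear(5, 4, bias=False), dim=0)

g = linear.weight_g.detach()  # shape (4, 1): one scale per row
v = linear.weight_v.detach()  # shape (4, 5)

# dim = 0: divide each row of v by its own Euclidean norm, then scale by g
w_manual = g * v / v.norm(dim=1, keepdim=True)

x = torch.randn(3, 5)
with torch.no_grad():
    y = linear(x)  # the forward pre-hook rebuilds the weight from g and v

# A Linear layer without bias computes x @ w.T
print(torch.allclose(y, x @ w_manual.T, atol=1e-6))  # True
```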

If we use dim = 1 instead (in a separate run, so the layer has a different random initialization), weight_g is:

weight_g Parameter containing:
tensor([[0.4383, 0.4057, 0.5172, 0.6300, 0.7336]], requires_grad=True)

With dim = 1, weight_g has one entry per column (one per input feature), and each column of \(v\) is normalized separately:

\(w_{:,j} = \frac{g_j}{\parallel v_{:,j} \parallel} v_{:,j}, \quad j = 1, \dots, 5\)
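The dim = 1 case works the same way, except that each column of v is divided by its own norm and g has shape (1, 5). A minimal sketch under the same illustrative assumptions as before (arbitrary seed and input):

```python
import torch
from torch.nn.utils import weight_norm

torch.manual_seed(0)  # arbitrary seed, only to make the sketch reproducible
linear = weight_norm(torch.nn.Linear(5, 4, bias=False), dim=1)

g = linear.weight_g.detach()  # shape (1, 5): one scale per column
v = linear.weight_v.detach()  # shape (4, 5)

# dim = 1: divide each column of v by its own Euclidean norm, then scale by g
w_manual = g * v / v.norm(dim=0, keepdim=True)

x = torch.randn(3, 5)
with torch.no_grad():
    y = linear(x)

print(torch.allclose(y, x @ w_manual.T, atol=1e-6))  # True
```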

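The other new parameter, weight_v, holds \(v\). It is initialized to the module's original weight, so \(v = w\) right after weight_norm() is called. The tensor below is weight_v from the dim = 1 run, so it does not match the weight printed at the start of this example: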

weight_v Parameter containing:
tensor([[-0.3479,  0.3607, -0.3595,  0.0582,  0.3540],
        [-0.0172,  0.1116,  0.1893,  0.4162,  0.3000],
        [-0.2173,  0.1215, -0.2943,  0.2343,  0.3633],
        [-0.1535,  0.0854,  0.1254, -0.4067, -0.4369]], requires_grad=True)

We can compare the outputs of linear and linear_norm on the same input:

inputs = torch.randn([3, 5])
x1 = linear(inputs)
x2 = linear_norm(inputs)
print("x1 = ", x1)
print("x2 = ", x2)

Running this code gives:

x1 =  tensor([[ 1.5303,  0.4063,  0.8504, -0.1195],
        [ 1.1122,  0.8465,  0.8764, -0.7010],
        [-0.7865, -1.3209, -0.8050,  0.6733]], grad_fn=<MmBackward0>)
x2 =  tensor([[ 1.5303,  0.4063,  0.8504, -0.1195],
        [ 1.1122,  0.8465,  0.8764, -0.7010],
        [-0.7865, -1.3209, -0.8050,  0.6733]], grad_fn=<MmBackward0>)

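Here x1 = x2. Keep in mind, though, that linear_norm and linear are the same object, so both lines call the same weight-normalized layer and the equality is guaranteed. The meaningful point is that, at initialization, the weight-normalized layer computes the same function as the original layer, because the reconstructed weight \(w = \frac{g}{\parallel v \parallel}v\) equals the original weight.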
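Because linear_norm is linear, comparing x1 with x2 cannot reveal any difference. To compare against the layer as it was before weight normalization, one option is to keep an untouched copy made with copy.deepcopy before calling weight_norm(). A minimal sketch, using the same layer shape as above:

```python
import copy

import torch
from torch.nn.utils import weight_norm

linear = torch.nn.Linear(5, 4, bias=False)
reference = copy.deepcopy(linear)         # untouched copy with the original weight

linear_norm = weight_norm(linear, dim=0)  # modifies `linear` in place

inputs = torch.randn(3, 5)
with torch.no_grad():
    y_ref = reference(inputs)
    y_norm = linear_norm(inputs)

# At initialization g = ||v|| and v = w, so the outputs agree up to rounding
print(torch.allclose(y_ref, y_norm, atol=1e-6))  # True
```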

Why is v = w at the beginning?

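weight_norm() initializes \(v\) to the original weight and \(g\) to \(\parallel v \parallel\) (computed along the chosen dimension). Then \(\frac{g}{\parallel v \parallel} = 1\), so \(w = \frac{g}{\parallel v \parallel}v = v\), which is the original weight. We will use an example to verify that \(g = \parallel v \parallel\).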

In the example above with dim = 1, \(g = [0.4383, 0.4057, 0.5172, 0.6300, 0.7336]\)

and weight_v, which equals the original weight \(w\) of the linear module, is:

[[-0.3479,  0.3607, -0.3595,  0.0582,  0.3540],
        [-0.0172,  0.1116,  0.1893,  0.4162,  0.3000],
        [-0.2173,  0.1215, -0.2943,  0.2343,  0.3633],
        [-0.1535,  0.0854,  0.1254, -0.4067, -0.4369]]

Let us check \(g = \parallel v \parallel\) at the beginning. With dim = 1 there is one norm per column, so we sum the squares over axis 0:

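import numpy as np

# weight_v from the dim = 1 run (equal to the original weight w), rounded to 4 decimals
data = [[-0.3479,  0.3607, -0.3595,  0.0582,  0.3540],
        [-0.0172,  0.1116,  0.1893,  0.4162,  0.3000],
        [-0.2173,  0.1215, -0.2943,  0.2343,  0.3633],
        [-0.1535,  0.0854,  0.1254, -0.4067, -0.4369]]
v = np.array(data)

# dim = 1: one Euclidean norm per column, i.e. sum the squares over axis 0
norms = np.sqrt(np.sum(v * v, axis=0))
print(norms)

# weight_g printed by PyTorch for the same run, shape (1, 5)
g = np.array([[0.4383, 0.4057, 0.5172, 0.6300, 0.7336]])
# The printed values are rounded, so compare with a small tolerance
assert np.allclose(norms, g.ravel(), atol=5e-4)

Running this code prints the column norms: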

[0.43830559 0.40572708 0.51711932 0.63000878 0.73361059]

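These match weight_g up to rounding, so \(g = \parallel v \parallel\) at the beginning, and therefore \(w = \frac{g}{\parallel v \parallel}v = v\).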

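For dim = 0 we can check the same property, this time with one norm per row (summing over axis 1). The values below come from a separate run of the first example, so they differ from the weights printed earlier: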

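import numpy as np

# weight_v from a separate dim = 0 run (equal to that run's original weight), rounded to 4 decimals
data = [[-0.2362, -0.3421,  0.4293,  0.0994,  0.1193],
        [-0.3963,  0.2133, -0.3640, -0.0877, -0.0750],
        [ 0.1906, -0.1868, -0.2297,  0.2133, -0.3716],
        [ 0.0871,  0.0613, -0.4081, -0.2980,  0.0644]]
v = np.array(data)

# dim = 0: one Euclidean norm per row, i.e. sum the squares over axis 1
norms = np.sqrt(np.sum(v * v, axis=1))
print(norms)

# weight_g printed by PyTorch for the same run, shape (4, 1)
g = np.array([[0.6175],
              [0.5902],
              [0.5545],
              [0.5204]])
# Flatten g so it lines up with the (4,) vector of norms; values are rounded
assert np.allclose(norms, g.ravel(), atol=5e-4)

Running this code prints the row norms: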

[0.61744165 0.59022273 0.55458826 0.52042393]
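These row norms also match weight_g up to rounding. Instead of copying printed values into NumPy, the same check can be run directly in PyTorch. The following is a minimal sketch that relies on the weight_g/weight_v parameter names used throughout this tutorial. For each dim, it confirms that v starts as an exact copy of the original weight and that g starts as the norm of v.

```python
import torch
from torch.nn.utils import weight_norm

results = {}
# (dim passed to weight_norm, axis the norm is taken over for a 2-D weight)
for dim, norm_axis in [(0, 1), (1, 0)]:
    linear = torch.nn.Linear(5, 4, bias=False)
    w_original = linear.weight.detach().clone()

    weight_norm(linear, dim=dim)
    g = linear.weight_g.detach()
    v = linear.weight_v.detach()

    # v starts as a copy of the original weight
    same_v = torch.equal(v, w_original)
    # g starts as the norm of v, taken over the axis that `dim` does not keep
    g_is_norm = torch.allclose(g, v.norm(dim=norm_axis, keepdim=True))
    results[dim] = (same_v, g_is_norm)
    print(dim, same_v, g_is_norm)  # "0 True True", then "1 True True"
```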
