We often create parameters with torch.nn.Parameter in a PyTorch torch.nn.Module. However, do you know how to initialize them? In this tutorial, we will discuss this topic.
Create a torch.nn.Parameter variable
It is easy to create a torch.nn.Parameter variable in PyTorch. For example:
import torch

weight = torch.nn.Parameter(torch.randn(5, 5))
print(weight)
Running this code, we will see:
Parameter containing:
tensor([[ 0.0265, -0.0283,  1.4823, -0.2538,  0.4521],
        [ 1.0145,  0.3092,  0.9226, -0.3336,  1.0320],
        [ 0.3867, -0.2104, -0.1408, -0.6255, -0.3301],
        [ 0.9972, -0.9325, -0.3661, -1.9633, -0.3348],
        [ 0.6539, -0.3673, -0.3820, -0.5793,  0.6751]], requires_grad=True)
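The main reason to wrap a tensor in torch.nn.Parameter rather than using a plain tensor is that assigning it as a module attribute registers it automatically, so optimizers and state dicts can find it. A minimal sketch (the class name MyLinear is just for illustration):

```python
import torch

class MyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a Parameter as an attribute registers it with the module.
        self.weight = torch.nn.Parameter(torch.randn(5, 5))

model = MyLinear()
# The parameter shows up in named_parameters() and requires gradients.
print([name for name, _ in model.named_parameters()])  # ['weight']
```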
Initialize torch.nn.Parameter variable with different methods
There are several methods that can initialize a torch.nn.Parameter variable.
For example:
import torch

weight = torch.nn.Parameter(torch.Tensor(5, 5))
print(weight)
Here we have created a 5×5 uninitialized tensor; its values are whatever happened to be in memory:
Parameter containing:
tensor([[8.4490e-39, 1.1112e-38, 1.0194e-38, 9.0919e-39, 8.7245e-39],
        [1.1112e-38, 4.2245e-39, 8.4489e-39, 9.6429e-39, 8.4490e-39],
        [9.6429e-39, 9.2755e-39, 1.0286e-38, 9.0919e-39, 8.9082e-39],
        [9.2755e-39, 8.4490e-39, 1.0194e-38, 9.0919e-39, 8.4490e-39],
        [1.0745e-38, 1.0653e-38, 1.0286e-38, 1.0194e-38, 9.2755e-39]], requires_grad=True)
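As an aside, torch.Tensor(5, 5) allocates memory without initializing it; in current PyTorch, torch.empty is the clearer spelling of the same idea. A small sketch:

```python
import torch

# torch.empty also allocates uninitialized memory, but states the
# intent explicitly; the values must be filled in before use.
weight = torch.nn.Parameter(torch.empty(5, 5))
print(weight.shape)  # torch.Size([5, 5])
```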
In order to initialize it, we can:
Use tensor built-in function
weight.data.uniform_(-1, 1)
print(weight)
In this code, we use the in-place method Tensor.uniform_() to fill weight with values drawn uniformly from [-1, 1).
Then we will get:
Parameter containing:
tensor([[-0.1689,  0.4635, -0.6464, -0.9490, -0.6724],
        [-0.6435,  0.9929,  0.4329,  0.8690,  0.9690],
        [ 0.4082, -0.2569, -0.2214, -0.9469, -0.5764],
        [ 0.3985,  0.3738,  0.9745, -0.2259,  0.1704],
        [-0.4165,  0.7454, -0.6311, -0.5067,  0.8072]], requires_grad=True)
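uniform_() is not the only option: other in-place Tensor methods (the trailing underscore marks in-place operations in PyTorch) work the same way. A quick sketch of a few of them:

```python
import torch

weight = torch.nn.Parameter(torch.empty(5, 5))

# Each in-place method fills the tensor and returns it.
weight.data.normal_(mean=0.0, std=0.01)  # Gaussian values
weight.data.fill_(0.5)                   # a constant value
weight.data.zero_()                      # all zeros
print(weight.data.sum().item())  # 0.0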
Use torch.nn.init function
For example:
torch.nn.init.xavier_uniform_(
    weight, gain=torch.nn.init.calculate_gain("linear"))
print(weight)
Here we use torch.nn.init.xavier_uniform_() (Glorot initialization) to initialize the weight, with the gain recommended for a linear layer.
Then we will get:
Parameter containing:
tensor([[ 0.2747, -0.5777,  0.5408, -0.4074,  0.5775],
        [ 0.1661,  0.7140,  0.0584,  0.4028, -0.4579],
        [-0.5976, -0.3443,  0.5204, -0.2844, -0.5127],
        [-0.6429,  0.2265, -0.4873, -0.7025, -0.4240],
        [-0.0778,  0.4199, -0.1536, -0.2102, -0.4628]], requires_grad=True)
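Putting the pieces together: PyTorch's built-in layers such as torch.nn.Linear do their initialization inside a reset_parameters() method, which is a convenient convention for custom modules too. A sketch of that pattern (the class name MyLayer and the chosen init scheme are just for illustration):

```python
import math
import torch

class MyLayer(torch.nn.Module):
    def __init__(self, in_features=5, out_features=5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming uniform for the weight, a small uniform range for the bias.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        bound = 1 / math.sqrt(self.weight.size(1))
        torch.nn.init.uniform_(self.bias, -bound, bound)

layer = MyLayer()
print(layer.weight.shape, layer.bias.shape)
```

Keeping the initialization in reset_parameters() also lets you re-initialize a module later by simply calling the method again.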