When saving or loading a PyTorch model, we often use model.state_dict(). Here is a tutorial:
Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial
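That tutorial covers the details. As a minimal sketch of the typical round trip, the example below saves only the state_dict and then loads it into a freshly built model of the same architecture. The single nn.Linear(2, 1) stand-in model and the file name model.pt (written to a temporary directory) are illustrative assumptions, not part of the linked guide.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model (an assumption for illustration); any nn.Module works the same way.
model = nn.Linear(2, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")  # hypothetical file name

    # Save only the parameter tensors, not the whole Python object.
    torch.save(model.state_dict(), path)

    # To load, first build a model with the same architecture,
    # then copy the saved tensors into it.
    restored = nn.Linear(2, 1)
    restored.load_state_dict(torch.load(path))

# Both models now hold identical weights and biases.
print(torch.equal(model.weight, restored.weight))  # True
print(torch.equal(model.bias, restored.bias))      # True
```

Saving the state_dict rather than the whole model keeps the file independent of the Python class definition; the price is that you must rebuild the model yourself before calling load_state_dict().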
But how should we understand model.state_dict()? In this tutorial, we will use an example to explain it.
What is model.state_dict() in PyTorch?
Look at this example:
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)   # 2 inputs -> 4 hidden units
        self.fc2 = nn.Linear(4, 3)   # 4 -> 3 hidden units
        self.out = nn.Linear(3, 1)   # 3 -> 1 output
        self.out_act = nn.Sigmoid()  # has no learnable parameters

    def forward(self, inputs):
        a1 = self.fc1(inputs)
        a2 = self.fc2(a1)
        a3 = self.out(a2)
        y = self.out_act(a3)
        return y


model_1 = Net()
Here we have created a model named model_1.
params = model_1.state_dict()
print(params)
Here we print the result of model_1.state_dict().
We will get:
OrderedDict([('fc1.weight', tensor([[-0.6612,  0.0033],
        [-0.6802,  0.4862],
        [-0.0021, -0.0534],
        [-0.3389,  0.0287]])), ('fc1.bias', tensor([-0.2084,  0.2652, -0.6688,  0.3957])), ('fc2.weight', tensor([[-0.1868, -0.1317, -0.4074, -0.2077],
        [-0.3127,  0.4837,  0.0855,  0.3983],
        [ 0.3161, -0.2485, -0.4806, -0.1920]])), ('fc2.bias', tensor([-0.1653, -0.0355,  0.2690])), ('out.weight', tensor([[-0.4043,  0.1675, -0.2861]])), ('out.bias', tensor([0.2775]))])
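As the result shows, model_1.state_dict() returns an OrderedDict that maps the name of each parameter in the model to its value (a tensor). Only the layers with learnable parameters (fc1, fc2 and out) appear; out_act is an nn.Sigmoid, which has no parameters, so it has no entry.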
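To make the structure of this OrderedDict easier to see, the sketch below rebuilds the same Net and prints the shape of each entry instead of its values (which are random on every run). The shapes follow from how nn.Linear(in_features, out_features) stores its weight as (out_features, in_features); for example, fc1.weight has 4 rows of 2 values, matching the output above.

```python
from collections import OrderedDict

import torch.nn as nn


class Net(nn.Module):
    # Same layers as model_1 above.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, inputs):
        return self.out_act(self.out(self.fc2(self.fc1(inputs))))


model_1 = Net()
params = model_1.state_dict()
print(isinstance(params, OrderedDict))  # True

# weight has shape (out_features, in_features), bias has shape (out_features,)
shapes = {name: tuple(t.shape) for name, t in params.items()}
for name, shape in shapes.items():
    print(name, shape)
# fc1.weight (4, 2)
# fc1.bias (4,)
# fc2.weight (3, 4)
# fc2.bias (3,)
# out.weight (1, 3)
# out.bias (1,)
```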
We can list the names of all parameters in the PyTorch model with keys():
print(params.keys())
It will output:
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'out.weight', 'out.bias'])
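For this model, these keys are exactly the names returned by model_1.named_parameters(). That is not true for every model: state_dict() also stores persistent buffers, which are not trainable parameters. The sketch below checks the first claim and uses nn.BatchNorm1d(3), chosen here only as an example of a module with buffers, to show the difference.

```python
import torch.nn as nn


class Net(nn.Module):
    # Same layers as model_1 above; forward() is omitted because
    # it is not needed to inspect parameters.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()


model_1 = Net()
state_keys = list(model_1.state_dict().keys())
param_names = [name for name, _ in model_1.named_parameters()]
print(state_keys == param_names)  # True: this model has no buffers

# BatchNorm1d keeps running statistics as buffers: they are saved in
# state_dict() but are not returned by named_parameters().
bn = nn.BatchNorm1d(3)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
```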
How to output parameter values by name?
We can do as follows:
print(params["fc1.weight"])
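Here fc1.weight is the name of the parameter: the attribute name of the layer (fc1) followed by the name of the parameter inside that layer (weight).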
We will see:
tensor([[-0.6612,  0.0033],
        [-0.6802,  0.4862],
        [-0.0021, -0.0534],
        [-0.3389,  0.0287]])
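Notice that the value printed above is a plain tensor without requires_grad=True. By default (keep_vars=False), state_dict() returns detached tensors that still share memory with the model's parameters. The sketch below, which rebuilds the same Net, compares params["fc1.weight"] with the attribute model_1.fc1.weight; the in-place fill_(0.5) is only there to demonstrate the shared memory.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Same layers as model_1 above; forward() omitted for brevity.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()


model_1 = Net()
params = model_1.state_dict()

w_state = params["fc1.weight"]  # looked up by name in the state_dict
w_param = model_1.fc1.weight    # the same parameter via attributes fc1 -> weight

print(torch.equal(w_state, w_param))                 # True: same values
print(w_state.requires_grad, w_param.requires_grad)  # False True
print(w_state.data_ptr() == w_param.data_ptr())      # True: same memory

# Because the memory is shared, an in-place change made through the
# parameter is also visible in the tensor taken from state_dict().
with torch.no_grad():
    model_1.fc1.weight.fill_(0.5)
print(w_state[0])  # tensor([0.5000, 0.5000])
```

If you need an independent snapshot of the weights (for example, to keep the best checkpoint during training), copy the tensors, e.g. with clone(), rather than keeping the state_dict() result.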