Understand PyTorch model.state_dict() – PyTorch Tutorial

By | April 12, 2023

When saving or loading a PyTorch model, we usually work with model.state_dict(). Here is a tutorial on saving and loading:

Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial

How should we understand model.state_dict()? In this tutorial, we will use an example to explain it.


What is model.state_dict() in PyTorch?

Look at this example:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, inputs):
        a1 = self.fc1(inputs)
        a2 = self.fc2(a1)
        a3 = self.out(a2)
        y = self.out_act(a3)
        return y

model_1 = Net()

Here we have created a model named model_1.

params = model_1.state_dict()
print(params)

This prints the value of model_1.state_dict().

We will get:

OrderedDict([('fc1.weight', tensor([[-0.6612,  0.0033],
        [-0.6802,  0.4862],
        [-0.0021, -0.0534],
        [-0.3389,  0.0287]])), ('fc1.bias', tensor([-0.2084,  0.2652, -0.6688,  0.3957])), ('fc2.weight', tensor([[-0.1868, -0.1317, -0.4074, -0.2077],
        [-0.3127,  0.4837,  0.0855,  0.3983],
        [ 0.3161, -0.2485, -0.4806, -0.1920]])), ('fc2.bias', tensor([-0.1653, -0.0355,  0.2690])), ('out.weight', tensor([[-0.4043,  0.1675, -0.2861]])), ('out.bias', tensor([0.2775]))])

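The result is an OrderedDict that maps each parameter name to its tensor value. The entries appear in the order the layers were defined in __init__(). The out_act layer (nn.Sigmoid) has no parameters, so it does not appear.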
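To make the structure of the dictionary easier to read, the following minimal sketch prints only the shape of each entry instead of its values. It redefines a stand-in Net with the same layer names and sizes as the example above (the forward pass is omitted because it is not needed here), and the variable name shapes is only illustrative.

```python
import torch.nn as nn

# Stand-in model with the same layer names and sizes as Net above.
# forward() is left out because state_dict() does not need it.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)

model = Net()

# Map each parameter name to the shape of its tensor.
shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}

for name, shape in shapes.items():
    print(name, shape)
```

Note that nn.Linear(2, 4) stores its weight with shape (4, 2), that is (out_features, in_features), which matches the 4 x 2 tensor printed for fc1.weight.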

We can list the names of all parameters in the PyTorch model:

print(params.keys())

It will output:

odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'out.weight', 'out.bias'])
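In this model every key is a learnable parameter, because Net has no buffers. As a side note, state_dict() can also contain persistent buffers that are not parameters. The sketch below uses nn.BatchNorm1d, which is not part of the example model and is chosen here only as an illustration, to compare the keys of state_dict() with the names from named_parameters().

```python
import torch.nn as nn

# BatchNorm1d is used only as an illustration; it is not part of Net.
bn = nn.BatchNorm1d(4)

# Names of learnable parameters only.
param_names = [name for name, _ in bn.named_parameters()]

# Names of everything stored in the state dict (parameters and buffers).
state_names = list(bn.state_dict().keys())

# Entries that are saved with the model but are not learnable parameters.
buffer_only = [name for name in state_names if name not in param_names]

print(param_names)
print(buffer_only)
```

Here buffer_only holds the running statistics that BatchNorm1d tracks, which are saved and loaded together with the weights even though the optimizer does not update them.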

How to output parameter values by name?

We can do as follows:

print(params["fc1.weight"])

Here fc1.weight is the name of the parameter: the attribute name of the layer (fc1) followed by the name of the tensor inside it (weight).

We will see:

tensor([[-0.6612,  0.0033],
        [-0.6802,  0.4862],
        [-0.0021, -0.0534],
        [-0.3389,  0.0287]])
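Because the state dict is an ordinary mapping from names to tensors, it can be saved and loaded back into another model with the same architecture. The following is a minimal sketch under some assumptions: make_model() is a hypothetical helper that builds a small nn.Sequential stand-in rather than the Net class above, and an in-memory io.BytesIO buffer takes the place of a file on disk.

```python
import io

import torch
import torch.nn as nn

def make_model():
    # Hypothetical helper: a small stand-in model, not the Net class above.
    return nn.Sequential(nn.Linear(2, 4), nn.Linear(4, 1))

src = make_model()  # model whose parameters we want to keep
dst = make_model()  # fresh model with different random parameters

# Save the state dict to an in-memory buffer (a file path would also work).
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# Load the saved tensors and copy them into the second model.
loaded = torch.load(buffer)
dst.load_state_dict(loaded)

# Both models should now produce the same output for the same input.
x = torch.randn(5, 2)
same = torch.allclose(src(x), dst(x))
print(same)
```

In nn.Sequential the keys are built from the layer index, for example 0.weight and 1.bias, in the same way that fc1.weight is built from the attribute name fc1.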