Initialize a PyTorch Model From a Pretrained Model – PyTorch Tutorial

By | September 27, 2023

We often use a pretrained PyTorch model to initialize our own model. In this tutorial, we will show you how to do it.

Load an existing model

We can use torch.load() to load an existing PyTorch model. Here is a tutorial:

Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial

Then, we can use this loaded model to initialize our own model.
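Here is a rough sketch of the torch.load() step. It assumes the checkpoint stores a whole model object, as the example later in this tutorial does, and PyTorch >= 1.13, where torch.load() accepts weights_only. A full model object is unpickled, so its class must be importable when loading. Starting with PyTorch 2.6, torch.load() defaults to weights_only=True, which refuses to unpickle an nn.Module, so weights_only=False has to be passed explicitly. Only do this for files you trust. The nn.Sequential model and the file name are stand-ins for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A small stand-in for a pretrained model (illustrative only).
pretrained = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))

# Hypothetical checkpoint path inside a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "pretrained_full.pt")

# Save the whole model object: architecture and weights are pickled together.
torch.save(pretrained, path)

# Load it back onto the CPU. weights_only=False is required to unpickle a full
# nn.Module on recent PyTorch versions; only use it for trusted files.
loaded = torch.load(path, map_location="cpu", weights_only=False)

print(type(loaded).__name__)                                 # Sequential
print(torch.equal(loaded[0].weight, pretrained[0].weight))   # True
```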

model.state_dict()

Initializing our model means setting the weights (parameters) of our model to the values stored in a loaded model. We can use model.state_dict() to see all parameters and persistent buffers (such as BatchNorm running statistics) in a PyTorch model. Here is the tutorial:

Understand PyTorch model.state_dict() – PyTorch Tutorial
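To show what the keys of a state_dict look like, here is a small sketch. TinyNet is a hypothetical model that mirrors the first layers of the Net class used below. Keys are built from the attribute names of the submodules, which is why weights can only be matched between two models through state_dict() when their layer names agree.

```python
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 3)
        self.act = nn.Sigmoid()  # has no parameters, so no state_dict entries

model = TinyNet()

# state_dict() maps "<submodule>.<tensor>" names to tensors.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# fc1.weight (4, 4)
# fc1.bias (4,)
# fc2.weight (3, 4)
# fc2.bias (3,)
```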

How to initialize a pytorch model from a pretrained model

To initialize a PyTorch model from a pretrained one, we can use model.load_state_dict(). Here is an example.
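Before the full example, here is a minimal sketch of how load_state_dict() behaves, using two stand-in nn.Sequential models with the same architecture. With strict=True (the default), every key must match. With strict=False, missing and unexpected keys are returned instead of raising an error.

```python
import torch
import torch.nn as nn

# Two independently initialized models with the same architecture (illustrative).
source = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))
target = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))

# Copy every tensor from source into target. strict=True (the default)
# raises an error if any key is missing or unexpected.
target.load_state_dict(source.state_dict())
print(torch.equal(target[0].weight, source[0].weight))  # True

# With strict=False, missing/unexpected keys are reported instead of raising.
partial = {"0.weight": source[0].weight.detach().clone()}
result = target.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['0.bias', '1.weight', '1.bias']
print(result.unexpected_keys)  # []
```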

First, we create a model and save it as our pretrained model.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        # freeze fc1's weight (this flag is not stored in state_dict())
        self.fc1.weight.requires_grad = False
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, inputs):
        a1 = self.fc1(inputs)
        a2 = self.fc2(a1)
        a3 = self.out(a2)
        y = self.out_act(a3)
        return y

# this model plays the role of the pretrained model
model = Net()
checkpoint = 'test.pt'
torch.save(model, checkpoint)  # save the whole model object, not only state_dict()
print('loading model: {}...'.format(checkpoint))
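The example above saves the whole model object. A commonly used alternative is to save only model.state_dict(), so the file contains just named tensors and the class definition is used to rebuild the model before loading. Here is a sketch of that approach. make_model() is a hypothetical stand-in for Net(), and the file name is made up for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Stand-in for Net(); any nn.Module works the same way.
    return nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3), nn.Linear(3, 1), nn.Sigmoid())

pretrained = make_model()
path = os.path.join(tempfile.mkdtemp(), "pretrained_state.pt")  # hypothetical path

# Save only the tensors (an ordered mapping of name -> tensor).
torch.save(pretrained.state_dict(), path)

# Build a fresh model from the class definition, then copy the saved tensors in.
new_model = make_model()
state = torch.load(path, map_location="cpu")
new_model.load_state_dict(state)
```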

Then, we can load it. Here we will create a function that loads an existing model and uses its weights to initialize a new model.

Create a new pytorch model.

model_test = Net()

Create a function to initialize this new model.

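def load_model(our_model, pretrained_model_path):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # load the whole pretrained model object; PyTorch >= 2.6 requires
    # weights_only=False for this (the argument exists since PyTorch 1.13)
    pretrained_model = torch.load(pretrained_model_path, map_location=device, weights_only=False)
    saved_state_dict = pretrained_model.state_dict()

    # get state_dict of our model
    state_dict = our_model.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        if k in saved_state_dict:
            # copy the pretrained tensor
            new_state_dict[k] = saved_state_dict[k]
        else:
            # keep our model's own initial tensor
            print("%s is not in the pretrained model" % k)
            new_state_dict[k] = v
    # copy the values into our_model (not the global model) in place
    our_model.load_state_dict(new_state_dict, strict=False)
    return our_model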
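load_model() only handles keys that are missing from the pretrained model. If a layer exists under the same name but with a different shape (for example, a new output layer), load_state_dict() raises a size-mismatch error even with strict=False. Here is a minimal sketch of a shape-aware variant, assuming the pretrained state_dict is already in memory. load_matching_weights, OldNet and NewNet are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.out = nn.Linear(4, 1)   # pretrained head: 1 output

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.out = nn.Linear(4, 5)   # new head: 5 outputs, different shape

def load_matching_weights(our_model, saved_state_dict):
    """Copy tensors whose name AND shape match; keep the rest as initialized."""
    own = our_model.state_dict()
    matched = {}
    for k, v in own.items():
        if k in saved_state_dict and saved_state_dict[k].shape == v.shape:
            matched[k] = saved_state_dict[k]
        else:
            print("skip %s" % k)
    # strict=False: keys left out of `matched` are simply not loaded.
    our_model.load_state_dict(matched, strict=False)
    return our_model

pretrained = OldNet()
model = load_matching_weights(NewNet(), pretrained.state_dict())
# skip out.weight
# skip out.bias
```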

Finally, we can initialize our new model with this function.

new_model = load_model(model_test, "test.pt")
print(new_model)

Running this code, we will see:

loading model: test.pt...
Net(
  (fc1): Linear(in_features=4, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=3, bias=True)
  (out): Linear(in_features=3, out_features=1, bias=True)
  (out_act): Sigmoid()
)
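The printed architecture only shows that new_model has the right layers. It does not show that the weights were actually copied. A quick check is to compare the two state_dicts tensor by tensor. The helper same_weights is hypothetical, and the nn.Sequential models below stand in for model and new_model.

```python
import torch
import torch.nn as nn

def same_weights(model_a, model_b):
    """Return True if both models have the same keys and identical tensors."""
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    if sd_a.keys() != sd_b.keys():
        return False
    # Compare on CPU so models on different devices can still be checked.
    return all(torch.equal(sd_a[k].cpu(), sd_b[k].cpu()) for k in sd_a)

a = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid())
b = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid())
print(same_weights(a, b))   # False: independent random initialization
b.load_state_dict(a.state_dict())
print(same_weights(a, b))   # True
```

In the tutorial's setting, same_weights(model, new_model) would be the corresponding check.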