We often use a pretrained PyTorch model to initialize our own model. In this tutorial, we will show you how to do it.
Load an existing model
We can use torch.load() to load an existing PyTorch model. The following tutorial gives an example:
Save and Load Model in PyTorch: A Completed Guide – PyTorch Tutorial
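As a minimal sketch of this step, the snippet below saves a small stand-in model and loads it back with torch.load(). The nn.Linear model and the file name demo_model.pt are assumptions made only for illustration. It also assumes a PyTorch version that accepts the weights_only argument, which must be False to unpickle a whole model object.

```python
import torch
import torch.nn as nn

# A tiny stand-in model; any nn.Module could be used here.
demo_model = nn.Linear(4, 2)

# Hypothetical file name, used only for this illustration.
demo_path = "demo_model.pt"

# Save the entire model object (not just its state_dict).
torch.save(demo_model, demo_path)

# weights_only=False allows unpickling a full model object.
# Only use it with files you trust.
loaded_model = torch.load(demo_path, map_location="cpu", weights_only=False)

print(type(loaded_model).__name__)
```

The loaded object is a regular nn.Module, so its state_dict() can be read just like that of the original model.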
Then, we can use this loaded model to initialize our own model.
model.state_dict()
Initializing our model means setting its weights (parameters) from a loaded model. We can use model.state_dict() to view all the weights and parameters in a PyTorch model. Here is the tutorial:
Understand PyTorch model.state_dict() – PyTorch Tutorial
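To make this concrete, here is a small sketch that prints the entries of a state_dict(). The single nn.Linear layer is an assumption chosen to keep the output short. A full model would list one entry per parameter or buffer in the same way.

```python
import torch.nn as nn

# A single layer is enough to show what state_dict() contains.
layer = nn.Linear(4, 3)

# Map each parameter name to the shape of its tensor.
shapes = {name: tuple(tensor.shape) for name, tensor in layer.state_dict().items()}

for name, shape in shapes.items():
    print(name, shape)
# weight (3, 4)
# bias (3,)
```

The keys of this dictionary are the names that model.load_state_dict() uses to match saved tensors to the layers of a model.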
How to initialize a pytorch model from a pretrained model
To initialize a PyTorch model, we can use model.load_state_dict(). Here is an example.
First, we save a model to act as the pretrained model.
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc1.weight.requires_grad = False  # freeze the weights of fc1
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, inputs):
        a1 = self.fc1(inputs)
        a2 = self.fc2(a1)
        a3 = self.out(a2)
        y = self.out_act(a3)
        return y

model = Net()

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

checkpoint = 'test.pt'
torch.save(model, checkpoint)  # save the whole model object, not state_dict()
print('loading model: {}...'.format(checkpoint))
Then, we can load it. Here we will create a function that loads an existing model and uses it to initialize a new model.
Create a new PyTorch model.
model_test = Net()
Create a function to initialize this new model.
def load_model(our_model, pretrained_model_path):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # load the pretrained model; the file holds a whole model object, so
    # weights_only=False is needed on PyTorch 2.6 and later (trusted files only)
    pretrained_model = torch.load(pretrained_model_path, map_location=device, weights_only=False)
    saved_state_dict = pretrained_model.state_dict()
    # get state_dict of our model
    state_dict = our_model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            # copy the pretrained tensor when the key exists in both models
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # otherwise keep the value our model already has
            print("%s is not in the pretrained model" % k)
            new_state_dict[k] = v
    our_model.load_state_dict(new_state_dict, strict=False)
    return our_model
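The effect of strict=False can be seen in isolation with a short sketch. The two nn.Sequential models here are assumptions made for illustration: the new model has one more layer than the pretrained one, so some of its keys have no saved counterpart. load_state_dict() reports those keys in its return value instead of raising an error.

```python
import torch
import torch.nn as nn

# Stand-in for a pretrained model with two layers.
pretrained = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3))

# New model with the same first two layers plus one extra layer.
new_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 3), nn.Linear(3, 1))

# strict=False copies every matching key and skips the rest.
result = new_model.load_state_dict(pretrained.state_dict(), strict=False)

print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # []
```

Layers listed in missing_keys keep their own initial values, which has the same effect as the fallback branch in the loop above.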
Finally, we can initialize our new model with this function.
new_model = load_model(model_test, "test.pt")
print(new_model)
Running this code, we will see:
loading model: test.pt...
Net(
  (fc1): Linear(in_features=4, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=3, bias=True)
  (out): Linear(in_features=3, out_features=1, bias=True)
  (out_act): Sigmoid()
)
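The printed structure only shows the layers, not whether the weights were actually copied. One possible check, sketched here with two stand-in nn.Linear layers (an assumption, not the Net model above), is to compare every tensor in the two state dicts after loading.

```python
import torch
import torch.nn as nn

# Two layers with the same shape but different random initial weights.
source = nn.Linear(4, 3)
target = nn.Linear(4, 3)

# Copy the weights from source into target.
target.load_state_dict(source.state_dict())

# Every tensor in target should now equal the matching tensor in source.
target_state = target.state_dict()
same = all(torch.equal(v, target_state[k]) for k, v in source.state_dict().items())
print(same)  # True
```

The same comparison could be applied to a pretrained model and the model returned by an initialization function, restricted to the keys both models share.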