When we train a PyTorch model, we may want to freeze some layers or parameters. In this tutorial, we will show you how to freeze some parameters and train the rest.
Look at this model below:
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 4)
        self.fc2 = nn.Linear(4, 3)
        self.out = nn.Linear(3, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, inputs):
        a1 = self.fc1(inputs)
        a2 = self.fc2(a1)
        a3 = self.out(a2)
        y = self.out_act(a3)
        return y

model_1 = Net()
This code creates a model named model_1.
We can display all parameters in this model with model_1.state_dict():
params = model_1.state_dict()
print(params.keys())
We will see:
odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'out.weight', 'out.bias'])
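If we also want to see the shape of each parameter tensor, we can iterate over the state dict. This small loop is our own illustration, not part of the original output:

for key, value in params.items():
    print(key, value.shape)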
You can learn more about PyTorch model.state_dict() in this tutorial:
Understand PyTorch model.state_dict() – PyTorch Tutorial
Then we can freeze some layers or parameters as follows:
for name, para in model_1.named_parameters():
    if name.startswith("fc1."):
        para.requires_grad = False
This code will freeze all parameters whose names start with "fc1.", which means autograd will no longer compute gradients for them.
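To verify that freezing works, we can run a forward and backward pass and then inspect the gradients. This is a minimal sketch; the random input x and the sum() loss are assumptions used only to trigger backward():

import torch

x = torch.randn(5, 2)  # a dummy batch of 5 samples with 2 features, for illustration only
y = model_1(x)
loss = y.sum()  # a dummy loss, only used so we can call backward()
loss.backward()

for name, para in model_1.named_parameters():
    # frozen parameters keep grad == None, trainable ones receive a gradient tensor
    print(name, para.grad is None)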
We can then list all parameters in the model and check whether each one is trainable:
for name, para in model_1.named_parameters():
    print(name, para.requires_grad)
List All Trainable Variables in PyTorch – PyTorch Tutorial
We will get:
fc1.weight False
fc1.bias False
fc2.weight True
fc2.bias True
out.weight True
out.bias True
In order to train this model, we should create an optimizer over the trainable (non-frozen) parameters only.
Here is an example:
non_frozen_parameters = [p for p in model_1.parameters() if p.requires_grad]
optimizer = optim.SGD(non_frozen_parameters, lr=0.1)
Then, we can start to train.
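For example, here is a minimal training sketch. The dummy inputs and targets and the choice of nn.BCELoss are our own assumptions, picked to match the model's input size (2) and its Sigmoid output; they are not part of the original tutorial:

import torch

# dummy data for illustration: 8 samples with 2 features and binary labels
inputs = torch.randn(8, 2)
targets = torch.randint(0, 2, (8, 1)).float()

criterion = nn.BCELoss()  # matches the Sigmoid output of the model

for epoch in range(10):
    optimizer.zero_grad()  # clear gradients of the non-frozen parameters
    outputs = model_1(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()  # updates only the non-frozen parameters
    print(epoch, loss.item())

Because the frozen fc1 parameters are not passed to the optimizer, they keep their initial values during the whole training loop.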