Implement SGD Optimizer with Warm-up in PyTorch – PyTorch Tutorial

October 25, 2022

When we use the SGD optimizer to train a PyTorch model, we may apply a warm-up strategy to improve training efficiency. In this tutorial, we will show you how to implement this strategy in PyTorch.

Preliminary

We can use the pytorch-gradual-warmup-lr package, which you can download here:

https://github.com/LvJC/pytorch-gradual-warmup-lr
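Assuming the repository ships a setup.py (check the repository first; this is our assumption), one way to install it is directly from GitHub:

pip install git+https://github.com/LvJC/pytorch-gradual-warmup-lr.git

Otherwise, clone the repository and make sure the warmup_scheduler package is on your PYTHONPATH.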

Implement SGD optimizer with warm-up

Here is an example code:

import torch
from warmup_scheduler import GradualWarmupScheduler
from matplotlib import pyplot as plt

lr_list = []
model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
LR = 0.001

optimizer = torch.optim.SGD(model, lr=LR, momentum=0.9, weight_decay=1e-3)
# after the warm-up, decay the learning rate by gamma = 0.97 every epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
# warm up the learning rate over the first 6 epochs, then hand over to StepLR
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=6, after_scheduler=scheduler)

for epoch in range(1, 200):
    data_size = 400
    for i in range(data_size):
        optimizer.zero_grad()
        # compute the loss and call loss.backward() here
        optimizer.step()
    scheduler_warmup.step(epoch)
    # record the learning rate of this epoch
    lr_list.append(scheduler_warmup.get_lr())

print(lr_list)
plt.plot(range(len(lr_list)), lr_list, color='r')
plt.show()

Run this code and we will see the learning rate curve as follows:

(Figure: learning rate curve with warm-up, training started from epoch = 1)

However, because we start training from epoch = 1, the first learning rate is 0.00016666666666666666.
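This value comes from the warm-up rule used by GradualWarmupScheduler: with multiplier = 1, the learning rate at a warm-up epoch is roughly base_lr * epoch / total_epoch (this description of the library's internals is our reading of its source; check the repository to confirm). A quick check:

LR = 0.001
total_epoch = 6
print(LR * 1 / total_epoch)  # 0.00016666666666666666, the learning rate at epoch = 1
print(LR * 0 / total_epoch)  # 0.0, the learning rate at epoch = 0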

If we start from epoch = 0, the first learning rate will be 0.

For example:

for epoch in range(0, 200):
    data_size = 400
    for i in range(data_size):
        optimizer.zero_grad()
        # compute the loss and call loss.backward() here
        optimizer.step()
    scheduler_warmup.step(epoch)
    # record the learning rate of this epoch
    lr_list.append(scheduler_warmup.get_lr())

print(lr_list)
plt.plot(range(len(lr_list)), lr_list, color='r')
plt.show()

We will see:

(Figure: learning rate curve with warm-up, training started from epoch = 0; the first learning rate is 0)

In order to get a correct warm-up learning rate, we should start training with epoch = 1.
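If you prefer not to depend on an external package, newer PyTorch versions (1.10 and later) ship LinearLR and SequentialLR, which can be combined to get a similar warm-up. Here is a minimal sketch, assuming PyTorch >= 1.10 and the same hyper-parameters as above:

import torch

model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
LR = 0.001

optimizer = torch.optim.SGD(model, lr=LR, momentum=0.9, weight_decay=1e-3)
# linear warm-up from LR / 6 to LR over the first 6 epochs
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0 / 6, total_iters=6)
# then decay the learning rate by gamma = 0.97 every epoch
decay = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[6])

for epoch in range(200):
    # training loop: optimizer.zero_grad(), loss.backward(), optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())

Note that in this built-in version scheduler.step() is called without an epoch argument, which is the recommended usage in recent PyTorch releases.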
