Implement a Warm-up Scheduler in PyTorch – PyTorch Example

January 17, 2023

We have introduced some warm-up strategies in PyTorch. Here are some related tutorials:

Implement Cosine Annealing with Warm up in PyTorch – PyTorch Tutorial

Implement SGD Optimizer with Warm-up in PyTorch – PyTorch Tutorial

Change Learning Rate By Step When Training a PyTorch Model Initiatively – PyTorch Tutorial

In this tutorial, we will use an example to show you how to create a warm-up scheduler without any extra package dependencies.

Why use a warm-up scheduler?

Warm-up is a simple yet effective way of addressing gradient problems in the first iterations of training. First, optimizers such as Adam use bias-correction factors, which can lead to a higher variance in the adaptive learning rate during the first iterations. Secondly, Layer Normalization applied repeatedly across layers can lead to very large gradients during the first iterations.
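To get some intuition for the first point, here is a small sketch of our own (not part of the scheduler below): it prints the weight that the most recent squared gradient receives in Adam's bias-corrected second-moment estimate. At step 1 this weight is 1.0, i.e. the adaptive learning rate is based on a single noisy gradient, which is one reason a warm-up phase with a small learning rate helps.

beta2 = 0.999  # Adam's default second-moment decay rate
for t in [1, 2, 10, 100, 1000]:
    # Weight of the most recent squared gradient in the bias-corrected
    # second-moment estimate: (1 - beta2) / (1 - beta2 ** t).
    weight = (1 - beta2) / (1 - beta2 ** t)
    print(f"step {t:4d}: weight of latest gradient = {weight:.4f}")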

How to create a warm-up scheduler?

Here is an example:

import torch
import numpy as np

class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing with a linear warm-up phase."""

    def __init__(self, optimizer, warmup, max_epoch, min_lr=1e-9):
        self.warmup = warmup            # number of warm-up epochs
        self.max_num_iters = max_epoch  # total number of training epochs
        self.min_lr = min_lr            # learning rate used at epoch 0
        super().__init__(optimizer)

    def get_lr(self):
        # At epoch 0 the warm-up factor would be 0, so return a small
        # minimum learning rate instead (one value per parameter group).
        if self.last_epoch == 0:
            return [self.min_lr for _ in self.base_lrs]

        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        # Cosine annealing over the whole schedule ...
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        # ... scaled linearly during the warm-up phase.
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In this Python class, warmup is the number of epochs used for warm-up. It can usually be set to warmup = 0.1 * max_epoch.

max_epoch is the total number of epochs you plan to train your model for; for example, it can be 100.
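For example, assuming a 100-epoch run with a base learning rate of 0.001 (illustrative values of our own, not requirements), a quick sketch like the following prints the first few learning rates and shows the linear ramp-up:

max_epoch = 100
warmup = int(0.1 * max_epoch)  # 10 warm-up epochs

params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.Adam(params, lr=0.001)
scheduler = CosineWarmupScheduler(optimizer, warmup, max_epoch)

# Print the learning rate for the first few epochs to see the linear ramp-up.
for epoch in range(5):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()  # the optimizer should be stepped before the scheduler
    scheduler.step()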

We can use this warm-up scheduler as follows:

if __name__ == "__main__":
    from matplotlib import pyplot as plt

    lr_list = []
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    LR = 0.001
    max_epoch = 200
    warmup_epochs = int(0.1 * max_epoch)  # 20 warm-up epochs

    optimizer = torch.optim.Adam(model, lr=LR, weight_decay=2e-5)
    scheduler = CosineWarmupScheduler(optimizer, warmup_epochs, max_epoch, 1e-7)
    for epoch in range(max_epoch):
        # Simulate one epoch of training: the optimizer is stepped for each
        # batch, while the scheduler is stepped once per epoch.
        data_size = 1000
        for i in range(data_size):
            optimizer.zero_grad()
            optimizer.step()
        # Record the learning rate used in this epoch.
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        scheduler.step()
    plt.plot(range(max_epoch), lr_list, color='r')
    plt.show()

In this example, we can see that the learning rate is 1e-7 when epoch = 0, because min_lr is set to 1e-7.

Running this code, we will see:

[Plot: learning rate vs. epoch, rising linearly during the warm-up epochs and then decaying along the cosine schedule.]

We can see that the learning rate changes from epoch to epoch: it increases linearly during warm-up and then follows the cosine annealing curve.
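Finally, here is a minimal sketch (our own addition) of how this scheduler could be wired into an ordinary training loop; the linear model, random data, and loss below are placeholders used only for illustration. The key point is that the optimizer is stepped once per batch, while the scheduler is stepped once per epoch.

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()

max_epoch = 100
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = CosineWarmupScheduler(optimizer, warmup=int(0.1 * max_epoch),
                                  max_epoch=max_epoch, min_lr=1e-7)

for epoch in range(max_epoch):
    for _ in range(10):  # stand-in for iterating over a real dataloader
        x = torch.randn(32, 10)
        y = torch.randn(32, 1)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    scheduler.step()  # update the learning rate once per epoch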