Understand torch.nn.utils.clip_grad_norm_() with Examples: Clip Gradient – PyTorch Tutorial


When reading papers, we may see statements such as: "All models are trained using Adam with a learning rate of 0.001 and gradient clipping at 2.0." In this tutorial, we will introduce how to implement gradient clipping in PyTorch.

Gradient clipping in deep learning

In deep learning, we can use gradient clipping to mitigate the exploding gradient problem: when the gradient norm grows too large, the gradients are rescaled so that their norm does not exceed a chosen threshold.
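To see what norm-based clipping does, here is a minimal sketch (the helper function and the toy layer below are illustrative, not PyTorch's own implementation): compute the total norm of all gradients, and if it exceeds max_norm, scale every gradient by max_norm / total_norm.

import torch

# Illustrative helper (not the library implementation): rescale gradients
# so that their total L2 norm does not exceed max_norm.
def clip_by_total_norm(parameters, max_norm, eps=1e-6):
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:  # only shrink gradients, never enlarge them
        for g in grads:
            g.mul_(scale)
    return total_norm

# Toy usage: a linear layer with artificially large gradients
layer = torch.nn.Linear(4, 2)
layer.weight.grad = torch.full_like(layer.weight, 100.0)
layer.bias.grad = torch.full_like(layer.bias, 100.0)
print(clip_by_total_norm(layer.parameters(), max_norm=2.0))  # large norm measured before clipping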

How to use gradient clipping in PyTorch?

In PyTorch, we can use torch.nn.utils.clip_grad_norm_() to implement gradient clipping.

This function is defined as:

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False)

It clips the gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector, and the gradients are modified in place.

Here:

parameters: an iterable of tensors (or a single tensor) whose gradients will be normalized

max_norm: the maximum allowed norm of the gradients

norm_type: the type of the p-norm to use (2.0, the L2 norm, by default)

error_if_nonfinite: if True, an error is raised when the total gradient norm is nan, inf or -inf

So "gradient clipping at 2.0" simply means max_norm = 2.0, as in the short example below.
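As a quick check of these parameters, the sketch below (a toy model and random data, purely for illustration) calls the function once; note that clip_grad_norm_() returns the total gradient norm measured before clipping, which is handy for deciding whether a given max_norm is ever reached.

import torch
import torch.nn as nn

# Toy model and data, for illustration only
model = nn.Linear(10, 1)
x = torch.randn(8, 10)
loss = model(x).pow(2).mean()
loss.backward()

# Returns the total norm of the gradients before clipping; after this call
# the gradient norm is at most max_norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
print(total_norm)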

torch.nn.utils.clip_grad_norm_() is easy to use: we call it between loss.backward() and optimizer.step(), so that the gradients are clipped after they have been computed and before the optimizer uses them, as shown in the sketch below.
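Here is a minimal training-loop sketch showing that placement (the model, data and hyperparameters below are made up for illustration):

import torch

# Placeholder model, optimizer and data, for illustration only
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()
data_loader = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]

for x, y in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()                                                    # 1. compute gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)   # 2. clip them
    optimizer.step()                                                   # 3. update weights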

Understand torch.nn.utils.clip_grad_norm_() with Examples

Here is an example:

for i, data_batch in enumerate(data_loader):
    # move all tensors except the last item of the batch to the GPU
    data_batch = [data.cuda() for data in data_batch[:-1]]

    bert_inputs, grid_labels, grid_mask2d, pieces2word, dist_inputs, sent_length = data_batch

    outputs = self.model(bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length)

    grid_mask2d = grid_mask2d.clone()
    loss = self.criterion(outputs[grid_mask2d], grid_labels[grid_mask2d])

    loss.backward()
    # clip the gradient norm before the optimizer updates the weights
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), config.clip_grad_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()

    loss_list.append(loss.cpu().item())

    # collect predictions and labels for evaluation
    outputs = torch.argmax(outputs, -1)
    grid_labels = grid_labels[grid_mask2d].contiguous().view(-1)
    outputs = outputs[grid_mask2d].contiguous().view(-1)

    label_result.append(grid_labels)
    pred_result.append(outputs)

    self.scheduler.step()

Here config.clip_grad_norm can be set to 2.0 or 5.0.
