
A Simple Guide to Fix Torch NaN Error – PyTorch Tutorial

We often use PyTorch to build deep learning models, and while training one we may run into a NaN error.

How can we fix this NaN error?

In this tutorial, we will use a simple method to find out exactly where the NaN comes from.

Step 1: Call torch.autograd.set_detect_anomaly(True) after importing the torch package.

For example:

import torch

# report the forward operation responsible for any NaN found during backward
torch.autograd.set_detect_anomaly(True)

Step 2: Wrap the backward pass of your loss in torch.autograd.detect_anomaly().

For example:

# loss = model(X)  # your existing forward pass computes the loss

# any NaN produced during backward raises an error that names the
# forward operation responsible
with torch.autograd.detect_anomaly():
    loss.backward()
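
Putting the two steps together, here is a minimal sketch of where both calls go in a training loop; the model, loss function, optimizer, and data below are placeholders, not part of the original example:

import torch
import torch.nn as nn

torch.autograd.set_detect_anomaly(True)  # Step 1: enable anomaly detection

# placeholder model, loss, optimizer and data, just to show where the calls go
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
X, y = torch.randn(32, 10), torch.randn(32, 1)

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(X), y)

    # Step 2: wrap the backward pass; a NaN produced during backward now raises
    # a RuntimeError whose traceback points at the offending forward operation
    with torch.autograd.detect_anomaly():
        loss.backward()

    optimizer.step()

Keep in mind that anomaly detection makes the backward pass noticeably slower, so only enable it while debugging.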

Then, you will get a RuntimeError whose message names the backward function that returned NaN values. In this example, it shows that the torch.sqrt() function is the one reporting the NaN.

Find the torch.sqrt() call in your loss function. It may look like this:

sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))

We can change it to:

sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(1e-12, 1))

Run the code again and the NaN error is fixed. The reason is that clamp(0, 1) lets torch.sqrt() receive an input of exactly 0 whenever cosine is ±1, and the gradient of sqrt(x) is 1 / (2 * sqrt(x)), which is infinite at 0; during the backward pass this infinite gradient turns into NaN values. Clamping the input to a small positive value such as 1e-12 keeps it away from 0, so the gradient stays finite.
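
To see the mechanism in isolation, here is a small sketch (not the original model) that reproduces the NaN from clamp(0, 1) and shows that clamp(1e-12, 1) keeps the gradient finite. The torch.where() below is only a stand-in that routes a zero upstream gradient into the sqrt output; zero multiplied by the infinite sqrt gradient at 0 is exactly how the NaN appears inside SqrtBackward:

import torch

def grad_of_cosine(eps):
    # cosine = 1.0 makes 1 - cosine^2 exactly 0, so torch.sqrt() receives 0
    cosine = torch.tensor([1.0, 0.5], requires_grad=True)
    sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(eps, 1))
    # torch.where() passes a zero upstream gradient into sine for the first
    # element; 0 * (infinite sqrt gradient at 0) = NaN inside SqrtBackward
    out = torch.where(cosine > 0.9, cosine, sine)
    out.sum().backward()
    return cosine.grad

print(grad_of_cosine(0.0))    # tensor([nan, -0.5774])
print(grad_of_cosine(1e-12))  # tensor([1.0000, -0.5774])

With anomaly detection enabled as in Step 1, the clamp(0, 1) case raises the SqrtBackward error instead of silently producing a NaN gradient.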