We often use PyTorch to build deep learning models, and while training them we may get a NaN error.
How can we fix this NaN error?
In this tutorial, we will introduce a simple way to find where it goes wrong.
Step 1: Call torch.autograd.set_detect_anomaly(True) after importing the torch package.
For example:
import torch

torch.autograd.set_detect_anomaly(True)
Step 2: Wrap the backward pass of your loss with torch.autograd.detect_anomaly().
For example:
# loss = model(X)
with torch.autograd.detect_anomaly():
    loss.backward()
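Putting the two steps together, here is a minimal, self-contained sketch of the pattern. The model, input, and loss below are made up for illustration; they are not the code from this tutorial.

import torch
import torch.nn as nn

# Step 1: enable anomaly detection right after importing torch.
torch.autograd.set_detect_anomaly(True)

# Hypothetical model and data, only to show where the calls go.
model = nn.Linear(4, 1)
X = torch.randn(8, 4)

# Forward pass as usual.
loss = model(X).pow(2).mean()

# Step 2: wrap the backward pass. If a NaN appears in any gradient,
# PyTorch raises an error whose traceback points at the forward
# operation that produced it.
with torch.autograd.detect_anomaly():
    loss.backward()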
Then, you may get an error during loss.backward(). The anomaly detector's traceback points to the operation that produced the NaN values; in this case, it reports that the torch.sqrt() function returned NaN.
Find the torch.sqrt() call in your loss function. It may look like this:
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
We can change it to:
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(1e-12, 1))
Run the code again, and you will find that the NaN error is fixed.
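Why does this small change help? The derivative of torch.sqrt(x) is 1 / (2 * sqrt(x)), which is not defined at x = 0. When cosine is exactly 1 or -1, the clamped value of 1.0 - cosine * cosine is 0, and if the gradient flowing back into that entry is also 0 (for example, because a later operation masks that entry out), the backward pass computes 0 / 0, which is NaN. Clamping with a tiny positive lower bound such as 1e-12 keeps the input of torch.sqrt() strictly positive. Below is a minimal sketch with made-up tensors, not the real loss from this tutorial, that reproduces the NaN and shows that the fixed version produces finite gradients:

import torch

# Made-up tensors, only to reproduce the problem: cosine = 1.0 makes
# 1.0 - cosine * cosine equal to 0, so sine = sqrt(0) = 0.
cosine = torch.tensor([1.0, 0.5], requires_grad=True)
weight = torch.tensor([0.0, 1.0])  # a later op that masks out the first entry

sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
(sine * weight).sum().backward()
# The backward of sqrt computes grad / (2 * sqrt(x)); for the first entry
# that is 0 / 0, which is NaN.
print(cosine.grad)  # tensor([nan, -0.5774])

# The fixed version keeps the input of torch.sqrt() strictly positive,
# so the division is well defined and the gradients stay finite.
cosine.grad = None
sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(1e-12, 1))
(sine * weight).sum().backward()
print(cosine.grad)  # tensor([0.0000, -0.5774])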