Understand model.zero_grad() and optimizer.zero_grad() – PyTorch Tutorial

March 23, 2023

We usually write a training loop in PyTorch as follows:

model = ...
loss_fn = ...
optimizer = ...

for epoch in range(...):
    for i, sample in enumerate(dataloader):
        inputs, labels = sample
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        # Compute loss and perform back-propagation
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Update parameters
        optimizer.step()

Here we call optimizer.zero_grad() to clear the gradients accumulated on the parameters managed by the optimizer; otherwise each loss.backward() call would add new gradients on top of the old ones.
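
To see why this matters, here is a minimal sketch (a toy linear model with random data, purely for illustration) showing that PyTorch accumulates gradients across backward() calls until they are cleared:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

loss_fn(model(x), y).backward()
first_grad = model.weight.grad.clone()

loss_fn(model(x), y).backward()  # gradients are added to .grad, not replaced
print(torch.allclose(model.weight.grad, 2 * first_grad))  # True

optimizer.zero_grad()  # clears .grad on every parameter the optimizer holds
print(model.weight.grad)  # None with set_to_none=True (the default in recent PyTorch; older versions zero the tensor instead)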

However, we also find training loops like this:

model = ...
loss_fn = ...
optimizer = ...

for epoch in range(...):
    for i, sample in enumerate(dataloader):
        inputs, labels = sample
        model.zero_grad()

        # Forward pass
        outputs = model(inputs)
        # Compute loss and perform back-propagation
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Update parameters
        optimizer.step()

Here we call model.zero_grad() to clear the gradients of every parameter registered in the model. What is the difference between the two?

Difference between model.zero_grad() and optimizer.zero_grad()

If we create the optimizer like this:

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=2e-5, amsgrad=True)

then all parameters of the model are passed to the optimizer, and model.zero_grad() is equivalent to optimizer.zero_grad().
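
As a rough illustration (a toy model with hypothetical shapes), both calls then touch exactly the same parameter tensors:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=2e-5, amsgrad=True)

model(torch.randn(8, 4)).sum().backward()

optimizer.zero_grad()  # clears .grad on every parameter it was given
# model.zero_grad()    # would clear exactly the same gradients here
print(all(p.grad is None for p in model.parameters()))  # True with set_to_none=True (the current default)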

Otherwise (for example, when the optimizer only receives a subset of the model's parameters, or parameters from several models), model.zero_grad() and optimizer.zero_grad() are not equivalent.
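
For example, here is a sketch (a hypothetical two-layer model where only the second layer is optimized, as when fine-tuning a head) in which the two calls behave differently:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
optimizer = torch.optim.SGD(model[1].parameters(), lr=0.1)  # only the second layer is optimized

model(torch.randn(8, 4)).sum().backward()

optimizer.zero_grad()                # clears the gradients of model[1] only
print(model[0].weight.grad is None)  # False: the first layer still holds its gradient

model.zero_grad()                    # clears every parameter registered in the model
print(model[0].weight.grad is None)  # True with set_to_none=True (the current default)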

In most models, where a single optimizer holds all of the model's parameters, model.zero_grad() and optimizer.zero_grad() can therefore be used interchangeably.