Understand PyTorch optimizer.step() with Examples – PyTorch Tutorial

When we use PyTorch to build and train a model, we have to call the optimizer.step() method. In this tutorial, we will use some examples to help you understand it.

PyTorch optimizer.step()

Here, optimizer is an instance of a PyTorch Optimizer class. The step() method is defined as:

Optimizer.step(closure)

It performs a single optimization step (a parameter update). If a closure is passed, it returns the loss computed by the closure; otherwise it returns None.

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
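
To make this concrete, here is a minimal runnable sketch (the parameter value, toy loss, and learning rate are made up for illustration). With plain SGD, step() updates each parameter as p = p - lr * p.grad:

import torch

# A single learnable parameter; the initial value 2.0 is arbitrary.
w = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

loss = (w ** 2).sum()  # a toy loss
loss.backward()        # fills w.grad with d(loss)/dw = 2 * w = 4.0

optimizer.step()       # SGD update: w = 2.0 - 0.1 * 4.0 = 1.6
print(w)               # tensor([1.6000], requires_grad=True)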

For example, torch.optim.Adam overrides the step() method. It evaluates the closure (if one is given) and returns the loss:

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # ...... (per-parameter-group setup code omitted here)

            F.adam(params_with_grad,
                   grads,
                   exp_avgs,
                   exp_avg_sqs,
                   max_exp_avg_sqs,
                   state_steps,
                   amsgrad=group['amsgrad'],
                   beta1=beta1,
                   beta2=beta2,
                   lr=group['lr'],
                   weight_decay=group['weight_decay'],
                   eps=group['eps'])
        return loss
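
The return value is easy to check with a small sketch (the "model" here is just one tensor, for illustration): without a closure, step() returns None; with a closure, it returns the loss the closure computed.

import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.Adam([w], lr=0.01)

# Without a closure, step() returns None.
loss = (w ** 2).sum()
loss.backward()
print(optimizer.step())  # None

# With a closure, step() returns the loss computed inside the closure.
def closure():
    optimizer.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    return loss

print(optimizer.step(closure))  # the loss tensor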

How to use optimizer.step()?

optimizer.step() is usually called in the training loop.

For example:

for input, target in dataset:
    optimizer.zero_grad() # step 1: clear old gradients
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward() # step 2: compute new gradients
    optimizer.step() # step 3: update parameters

optimizer.step() must be called after loss.backward(), because step() updates the parameters using the gradients that backward() has just computed.
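
The loop above assumes that model, loss_fn and dataset already exist. Here is a self-contained sketch, with a hypothetical linear model and random data standing in for them:

import torch
import torch.nn as nn

# Hypothetical stand-ins for model, loss_fn and dataset.
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
dataset = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(5)]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for input, target in dataset:
    optimizer.zero_grad()          # step 1: clear old gradients
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()                # step 2: compute new gradients
    optimizer.step()               # step 3: update parameters
    print(loss.item())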

We can also pass a closure callable to optimizer.step().

For example:

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)
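
The closure form matters most for optimizers such as torch.optim.LBFGS, which reevaluate the loss several times in a single step and therefore require a closure. Here is a sketch using the same hypothetical model and data as above:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                # hypothetical model
loss_fn = nn.MSELoss()
input, target = torch.randn(4, 10), torch.randn(4, 1)  # hypothetical data
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)

def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss

# LBFGS may call closure() many times inside one step().
loss = optimizer.step(closure)
print(loss.item())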
