Understand PyTorch Module forward() Function – PyTorch Tutorial

By | December 17, 2021

When we build a PyTorch module, we need to implement a forward() function. For example:

[Image: understand pytorch module forward() function]

In this example code, Backbone is a PyTorch module, and we implement a forward() function in it.

However, when is the forward() function called?

In the example above, you may find this code:
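As a minimal sketch of such a module: the layer sizes and the Model wrapper below are assumptions chosen for illustration, not the exact code from the example above.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical backbone: one linear layer followed by ReLU."""

    def __init__(self, in_features: int = 8, embed_dim: int = 4):
        super().__init__()
        self.fc = nn.Linear(in_features, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward() defines what the module computes for an input.
        return torch.relu(self.fc(x))


class Model(nn.Module):
    """Hypothetical wrapper that owns a Backbone instance."""

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The instance is called like a function; forward() is not called by name.
        embedding = self.backbone(x)
        return embedding


model = Model()
x = torch.randn(2, 8)   # batch of 2 samples, 8 features each
embedding = model(x)    # shape: (2, 4)
```

Note that neither Model nor the calling code ever writes backbone.forward(x) explicitly.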

embedding = self.backbone(x)

This code calls forward() in Backbone. Why?

You can find the answer in the PyTorch Module source code:

https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module

In this source code, we can find:

[Image: explanation of the PyTorch Module forward() function]

The torch.nn.Module class implements __call__(), which runs _call_impl(). If no hooks are registered, _call_impl() calls self.forward() directly. If hooks are registered, self.forward() is still called, with the hooks running before and after it.

__call__() makes a torch.nn.Module instance callable like a function. You can read more about this here:
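The dispatch can be sketched in plain Python. TinyModule below is a simplified stand-in written for illustration, not the real torch.nn.Module source, which handles several more hook types and other details.

```python
class TinyModule:
    """Simplified stand-in for torch.nn.Module (not the real implementation)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")

    def _call_impl(self, *args, **kwargs):
        # Fast path: no hooks registered, so just run forward().
        if not self._forward_hooks:
            return self.forward(*args, **kwargs)
        # Otherwise run forward(), then let each hook inspect or replace the result.
        result = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, args, result)
            if hook_result is not None:
                result = hook_result
        return result

    def __call__(self, *args, **kwargs):
        return self._call_impl(*args, **kwargs)


class Double(TinyModule):
    def forward(self, x):
        return 2 * x


m = Double()
plain = m(3)    # no hooks: forward() runs directly
m.register_forward_hook(lambda mod, inputs, output: output + 1)
hooked = m(3)   # forward() still runs, then the hook adjusts the output
```

In both calls forward() executes; the hook only changes what happens around it.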

Python Make a Class Instance Callable Like a Function – Python Tutorial
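The same mechanism works for any Python class. In this small sketch, Greeter is a hypothetical class used only to show that calling an instance runs its __call__() method.

```python
class Greeter:
    def __init__(self, greeting):
        self.greeting = greeting

    def __call__(self, name):
        # Runs whenever the instance is called like a function.
        return f"{self.greeting}, {name}!"


greet = Greeter("Hello")
message = greet("PyTorch")         # equivalent to the explicit call below
same = greet.__call__("PyTorch")
```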

As for this code:

embedding = self.backbone(x)

self.backbone is a Backbone instance, so calling it invokes __call__(), which in turn calls forward().

That is the secret of the PyTorch module forward() function.
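A quick way to check this behaviour is to count forward() calls and register a forward hook. The sketch below assumes a hypothetical Backbone with a single linear layer; the forward_calls counter and seen_shapes list exist only for the demonstration.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)
        self.forward_calls = 0  # illustrative counter, not part of nn.Module

    def forward(self, x):
        self.forward_calls += 1
        return self.fc(x)


backbone = Backbone()
x = torch.randn(2, 8)

out_call = backbone(x)            # goes through __call__(), which runs forward()
out_direct = backbone.forward(x)  # same computation, but bypasses __call__()

seen_shapes = []
handle = backbone.register_forward_hook(
    lambda module, inputs, output: seen_shapes.append(tuple(output.shape))
)
backbone(x)           # the hook fires
backbone.forward(x)   # the hook does not fire
handle.remove()
```

Calling the instance and calling forward() directly give the same output here, but only the instance call triggers registered hooks, which is one reason to prefer backbone(x) over backbone.forward(x).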
