Understand torch.nn.Module.modules() with Examples – PyTorch Tutorial

January 31, 2024

In this tutorial, we will use some examples to show how to use torch.nn.Module.modules().

Syntax

torch.nn.Module.modules() returns an iterator over all modules in the network: the module it is called on, followed by every submodule nested inside it.
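To make "all modules in the network" concrete, here is a minimal sketch. The nested network below is a hypothetical one made up for illustration; it shows that modules() walks the whole tree, not just the direct children, and that the result can be filtered with isinstance.

```python
import torch.nn as nn

# Hypothetical nested network, used only for illustration.
net = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Linear(8, 2),
)

# modules() yields the root, the inner Sequential, and every leaf layer.
all_names = [type(mod).__name__ for mod in net.modules()]
print(all_names)

# A common pattern: keep only the layers of one type.
linear_layers = [mod for mod in net.modules() if isinstance(mod, nn.Linear)]
print(len(linear_layers))
```

The first print should list five entries (two Sequential containers, two Linear layers and one ReLU), and the second should report two Linear layers.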

How to use it

Here is an example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

        # A standalone parameter (not a module), initialized with random values
        self.weight = nn.Parameter(torch.randn(200, 192))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

m = Model()
print(len(list(m.modules())))


for x in m.modules():
    print(x)

Running this code prints:

3
Model(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))
)
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
Conv2d(20, 20, kernel_size=(5, 5), stride=(1, 1))

From the result, we can see:

  • Calling .modules() returns the module itself and all of its child modules.
  • Parameters that belong to a module (such as self.weight) are not returned; use .parameters() to get them.

In this example, module m contains two child modules (self.conv1 and self.conv2), so m.modules() returns 3 modules: m itself plus the two children.
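The difference between modules and parameters can be checked directly. The sketch below uses a hypothetical TinyModel that mirrors the example above, and compares what modules() yields with what named_parameters() yields.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
        self.weight = nn.Parameter(torch.randn(200, 192))

model = TinyModel()

# modules() yields only nn.Module objects: the model and its two conv layers.
module_types = [type(mod).__name__ for mod in model.modules()]
print(module_types)

# The standalone parameter shows up through named_parameters() instead.
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```

Here module_types should contain three entries, while param_names should contain five: the standalone weight plus the weight and bias of each convolution.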

One thing to note:

Duplicate modules are returned only once.

For example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)

Although l appears twice in net, both entries refer to the same object, so l is returned only once.
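Deduplication is based on object identity, not on layer configuration. As a small sketch (the variable names are made up for illustration), compare a container that reuses one Linear instance with a container that holds two separate instances of the same shape:

```python
import torch.nn as nn

# One Linear instance used twice: the same object appears at both positions.
shared = nn.Linear(2, 2)
net_shared = nn.Sequential(shared, shared)

# Two separate Linear instances with identical shapes.
net_distinct = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

shared_count = len(list(net_shared.modules()))
distinct_count = len(list(net_distinct.modules()))
print(shared_count, distinct_count)
```

The shared version should report 2 modules (the Sequential and the single Linear), while the distinct version should report 3, because the two Linear layers are different objects even though they look the same when printed.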