Implement Wrapper Class of torch.nn.Linear() – PyTorch Tutorial

By | March 22, 2023

torch.nn.Linear() is one of the most commonly used modules in PyTorch. It is defined as:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

Here is a simple example of how to use it:

>>> import torch
>>> from torch import nn
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
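To make the shapes concrete, the sketch below inspects the parameters of the layer above and checks that nn.Linear computes y = x @ W^T + b, where W has shape (out_features, in_features). It also shows that bias=False leaves the bias attribute set to None. The seed and variable names are only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

m = nn.Linear(20, 30)
x = torch.randn(128, 20)

# weight is stored as (out_features, in_features); bias as (out_features,)
print(m.weight.shape)  # torch.Size([30, 20])
print(m.bias.shape)    # torch.Size([30])

# nn.Linear applies y = x @ W^T + b
manual = x @ m.weight.T + m.bias
print(torch.allclose(m(x), manual, atol=1e-5))  # True

# With bias=False, no bias parameter is created
no_bias = nn.Linear(20, 30, bias=False)
print(no_bias.bias)  # None
```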

However, if you want to change how its weight and bias are initialized, a convenient and reusable way is to create a wrapper class for it.
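It helps to know what nn.Linear does by default. In current PyTorch versions, its reset_parameters() method draws both the weight and the bias from a uniform distribution bounded by 1/sqrt(in_features). The sketch below checks that bound, and then shows that the torch.nn.init functions can also re-initialize an existing layer in place, which may be enough when you do not need a reusable class.

```python
import math

import torch
from torch import nn

in_features, out_features = 20, 30
m = nn.Linear(in_features, out_features)

# Default init: weight and bias lie in [-1/sqrt(in_features), 1/sqrt(in_features)]
# (a tiny tolerance absorbs float32 rounding of the bound)
bound = 1 / math.sqrt(in_features)
print(m.weight.abs().max().item() <= bound + 1e-6)  # True
print(m.bias.abs().max().item() <= bound + 1e-6)    # True

# The init functions modify parameters in place, so an existing layer
# can be re-initialized without any wrapper class
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
print(torch.all(m.bias == 0).item())  # True
```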

How to create a wrapper class for torch.nn.Linear()?

Here is an example:

import torch.nn as nn
from torch import Tensor
from torch.nn import init


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    The weight is initialized with Xavier uniform initialization and the bias is initialized to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Overwrite the default initialization of the wrapped layer
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)

In this code, linear.weight is initialized with Xavier uniform initialization and linear.bias is initialized to zeros.
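A different technique than the wrapper is to subclass nn.Linear and override its reset_parameters() method, which nn.Linear.__init__ calls after creating the weight and bias. The class name XavierLinear below is hypothetical; this is a sketch of that alternative, not the approach used in this tutorial.

```python
import math

import torch
from torch import nn


class XavierLinear(nn.Linear):
    """Hypothetical alternative: subclass nn.Linear instead of wrapping it.

    nn.Linear.__init__ calls self.reset_parameters(), so overriding it
    changes how weight and bias are initialized.
    """

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


layer = XavierLinear(20, 30)
print(list(layer.state_dict().keys()))    # ['weight', 'bias']
print(torch.all(layer.bias == 0).item())  # True

# Xavier uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
a = math.sqrt(6 / (20 + 30))
print(layer.weight.abs().max().item() <= a + 1e-6)  # True
```

Because XavierLinear is still an nn.Linear, isinstance checks pass and its state_dict keys (weight, bias) match a plain nn.Linear, whereas the wrapper class above stores them as linear.weight and linear.bias.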

We can also use other methods to initialize them. Here are some tutorials:

Understand torch.nn.init.xavier_uniform_() and torch.nn.init.xavier_normal_() with Examples – PyTorch Tutorial

Understand torch.nn.init.normal_() with Examples – PyTorch Tutorial
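If you want to switch between such methods without writing a new class each time, one option is to pass the initializers into the wrapper. The sketch below assumes a hypothetical InitLinear class with weight_init and bias_init arguments; it uses torch.nn.init.normal_() with std=0.02 purely as an example value.

```python
from typing import Callable

import torch
from torch import Tensor, nn
from torch.nn import init


class InitLinear(nn.Module):
    """Hypothetical wrapper of nn.Linear whose initializers are passed in."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        weight_init: Callable[[Tensor], Tensor] = init.xavier_uniform_,
        bias_init: Callable[[Tensor], Tensor] = init.zeros_,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Each init function modifies the parameter in place
        weight_init(self.linear.weight)
        if bias:
            bias_init(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


# Example: normal initialization (mean 0, std 0.02) for the weight
layer = InitLinear(512, 256, weight_init=lambda w: init.normal_(w, mean=0.0, std=0.02))
out = layer(torch.randn(8, 512))
print(out.shape)                        # torch.Size([8, 256])
print(layer.linear.weight.std().item())  # close to 0.02
```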