torch.nn.Linear() is one of the most commonly used modules in PyTorch. It is defined as:
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
We can use it as follows:
>>> import torch
>>> from torch import nn
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
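As a quick sketch of the optional arguments in the signature above (assuming a PyTorch version that supports the device/dtype factory arguments), bias and dtype can be used like this:

>>> m = nn.Linear(20, 30, bias=False)           # disable the bias term
>>> print(m.bias)
None
>>> m = nn.Linear(20, 30, dtype=torch.float64)  # create the parameters in float64
>>> print(m.weight.dtype)
torch.float64
>>> print(m.weight.shape)                       # weight is stored as (out_features, in_features)
torch.Size([30, 20])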
However, if you want to change the initialization method of its weight and bias, one common approach is to create a wrapper class for it.
How to create a wrapper class for torch.nn.Linear()?
Here is an example:
import torch
from torch import nn, Tensor
from torch.nn import init


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weight is initialized with Xavier initialization and bias is initialized to zeros.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)
In this code, linear.weight is initialized with Xavier uniform initialization and linear.bias is initialized to zeros.
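For example, here is a quick check (a sketch built on the wrapper above) that it behaves like a plain nn.Linear while using the custom initialization:

>>> m = Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
>>> print(torch.all(m.linear.bias == 0))  # bias was initialized to zeros
tensor(True)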
We can also use other methods to initialize them; here is a related tutorial:
Understand torch.nn.init.normal_() with Examples – PyTorch Tutorial
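For instance, a variant of the wrapper above that uses torch.nn.init.normal_() instead of Xavier initialization might look like this (the class name NormalLinear and the mean/std values are illustrative choices, not from the tutorial above):

import torch
from torch import nn, Tensor
from torch.nn import init


class NormalLinear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weight is drawn from a normal distribution and bias is initialized to zeros.
    The mean/std values below are illustrative, not prescribed.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(NormalLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.normal_(self.linear.weight, mean=0.0, std=0.02)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)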