When using torch.nn.Conv1d(), we may want the input and output to have the same shape. In this tutorial, we will show you how to do it.
torch.nn.Conv1d()
In order to use torch.nn.Conv1d() correctly, we can read this tutorial:
Understand torch.nn.Conv1d() with Examples – PyTorch Tutorial
From this tutorial, we can find:
The input shape is (N, C_in, L_in) and the output shape is (N, C_out, L_out), where N is the batch size, C is the number of channels and L is the sequence length.
For the input and output to have the same shape, we need:
C_in = C_out, L_in = L_out
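A quick way to see why neither condition holds automatically is to run a plain nn.Conv1d with its default arguments. The sketch below is only an illustration; the tensor sizes are arbitrary choices that match the example used later in this tutorial.

```python
import torch
import torch.nn as nn

x = torch.rand(32, 200, 50)  # (N, C_in, L_in)

# Defaults are stride=1, padding=0, dilation=1: the length shrinks by kernel_size - 1
out_a = nn.Conv1d(200, 200, kernel_size=3)(x)
print(out_a.shape)  # torch.Size([32, 200, 48])

# out_channels alone sets C_out, independently of C_in
out_b = nn.Conv1d(200, 100, kernel_size=1)(x)
print(out_b.shape)  # torch.Size([32, 100, 50])
```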
How to make the input and output have the same shape in Conv1d?
Two factors determine the output shape.
(1) C_out, which is set by the out_channels argument when initializing Conv1d().
We can set out_channels = in_channels, so that C_out = C_in.
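(2) L_out, which PyTorch computes as:
$$L_{out} = \left\lfloor \frac{L_{in} + 2 \cdot \text{padding} - \text{dilation} \cdot (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1 \right\rfloor$$
With stride = 1 this simplifies to $L_{out} = L_{in} + 2 \cdot \text{padding} - \text{dilation} \cdot (\text{kernel\_size} - 1)$, so L_out = L_in requires:
$$2 \cdot \text{padding} - \text{dilation} \cdot (\text{kernel\_size} - 1) = 0$$
That is, padding = dilation * (kernel_size - 1) / 2, which is a whole number only when dilation * (kernel_size - 1) is even (for example, when kernel_size is odd).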
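To check the length formula against PyTorch itself, the sketch below defines a small helper, conv1d_out_len (a hypothetical name for illustration, not a PyTorch function), and compares its prediction with the length nn.Conv1d actually produces when padding = dilation * (kernel_size - 1) / 2. The kernel sizes and dilations tested are arbitrary choices, all with odd kernel_size.

```python
import torch
import torch.nn as nn


def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    # Output-length formula from the torch.nn.Conv1d documentation (// is floor division)
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


L_in = 50
x = torch.rand(2, 4, L_in)
checks = []
for kernel_size, dilation in [(1, 1), (3, 1), (5, 1), (3, 2), (7, 3)]:
    # Choose padding so that 2 * padding == dilation * (kernel_size - 1)
    padding = dilation * (kernel_size - 1) // 2
    conv = nn.Conv1d(4, 4, kernel_size, stride=1, padding=padding, dilation=dilation)
    predicted = conv1d_out_len(L_in, kernel_size, 1, padding, dilation)
    actual = conv(x).shape[-1]
    checks.append((predicted, actual))
    print(f"k={kernel_size} d={dilation} padding={padding}: predicted={predicted} actual={actual}")

# With stride=2 the same padding no longer keeps the length (50 -> 25)
print(conv1d_out_len(L_in, 3, stride=2, padding=1))
```

Every stride-1 row should report a length of 50, while the stride-2 case shows why the condition above is stated only for stride = 1.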
Here is an example that computes this padding automatically.
```python
import torch
import torch.nn.functional as F
import torch.nn as nn


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            # "Same" padding needs dilation * (kernel_size - 1) to be even,
            # which an odd kernel_size guarantees; it keeps L only when stride == 1
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        # Xavier-uniform weights, scaled by the gain for w_init_gain
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal
```
In order to understand torch.nn.init.calculate_gain(), you can read:
Understand torch.nn.init.calculate_gain() with Examples – PyTorch Tutorial
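As a brief illustration of what calculate_gain() feeds into xavier_uniform_() inside ConvNorm: for a Conv1d weight of shape (out_channels, in_channels, kernel_size), xavier_uniform_() samples from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), where fan_in = in_channels * kernel_size and fan_out = out_channels * kernel_size. The sketch below checks that bound for an arbitrary 200-channel layer with kernel_size = 3.

```python
import math
import torch

# Gains for a few nonlinearities: 'linear' -> 1, 'relu' -> sqrt(2), 'tanh' -> 5/3
gain_linear = torch.nn.init.calculate_gain('linear')
gain_relu = torch.nn.init.calculate_gain('relu')
gain_tanh = torch.nn.init.calculate_gain('tanh')

conv = torch.nn.Conv1d(200, 200, kernel_size=3)
torch.nn.init.xavier_uniform_(conv.weight, gain=gain_linear)

# Weight shape is (out_channels, in_channels, kernel_size)
fan_in = 200 * 3
fan_out = 200 * 3
bound = gain_linear * math.sqrt(6.0 / (fan_in + fan_out))

print(gain_linear, gain_relu, gain_tanh)
print(bound, conv.weight.abs().max().item())  # the max should not exceed the bound
```

Because ConvNorm defaults to w_init_gain='linear', its weights are plain Xavier-uniform with gain 1.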
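In this code, we wrap torch.nn.Conv1d() in a class, ConvNorm. When padding is not given, it is set to dilation * (kernel_size - 1) / 2, so with stride = 1 and in_channels = out_channels the input and output have the same shape.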
For example:
```python
N = 32
C = 200
L = 50
input = torch.rand(N, C, L)
conv1d = ConvNorm(C, C)
output = conv1d(input)
print(output.shape)
```
Running this code prints:
torch.Size([32, 200, 50])
With kernel_size = 3:
```python
conv1d = ConvNorm(C, C, kernel_size=3)
output = conv1d(input)
print(output.shape)
```
The output shape is still:
torch.Size([32, 200, 50])
The input and output have the same shape.
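The same rule covers dilation, and it also shows where it stops working. The sketch below uses a trimmed copy of ConvNorm (bias and weight-initialization options dropped, an assumption made only to keep the example short) to try a dilated kernel, a stride of 2, and an even kernel_size. For comparison it also uses nn.Conv1d's built-in padding='same' option, which assumes PyTorch 1.9 or later and only supports stride = 1.

```python
import torch
import torch.nn as nn


class ConvNorm(nn.Module):
    # Trimmed copy of ConvNorm: bias and weight-init options omitted for brevity
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, dilation=dilation)

    def forward(self, signal):
        return self.conv(signal)


x = torch.rand(32, 200, 50)

# kernel_size=5, dilation=2 -> padding = 2 * 4 / 2 = 4, length preserved
out_dilated = ConvNorm(200, 200, kernel_size=5, dilation=2)(x)
print(out_dilated.shape)

# stride=2 -> the computed padding no longer preserves the length
out_strided = ConvNorm(200, 200, kernel_size=3, stride=2)(x)
print(out_strided.shape)

# Built-in alternative (PyTorch >= 1.9); stride must stay 1
out_builtin = nn.Conv1d(200, 200, kernel_size=5, dilation=2, padding='same')(x)
print(out_builtin.shape)

# An even kernel_size trips the assertion in ConvNorm
try:
    ConvNorm(200, 200, kernel_size=4)
    even_rejected = False
except AssertionError:
    even_rejected = True
print(even_rejected)
```

The dilated and built-in cases keep the length at 50, the strided case drops it to 25, and the even kernel is rejected. Note that the assertion check is skipped if Python runs with the -O flag.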