Hi,
I made a custom module containing N branches (N = out_channels) of Conv2d objects, all with the same configuration (in_channels, x, 1, padding=padding, bias=False). During the forward pass, the input is passed through each branch, and the results are stacked into the final output, which has 5 dimensions: (batch_size, N, h, w, x).
The code is as follows:
import torch
import torch.nn as nn


class MyModule(nn.Module):
    """
    A convolutional module built from parallel 1x1 convolution branches.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels (one branch per output channel).
        x (int): Number of channels produced by each branch.
        kernel_size (int, optional): Size of the convolution kernel. Defaults to 3.
        stride (int, optional): Stride. Defaults to 1.
        padding (int, optional): Amount of padding to add to the input tensor. Defaults to 1.
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Attributes:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        x (int): Number of channels produced by each branch.
        kernel_size (int): Kernel size of the weight tensor.
        stride (int): Stride.
        padding (int): Amount of padding to add to the input tensor.
    """

    def __init__(self, in_channels, out_channels, x, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x = x
        # kernel_size, stride and padding are stored but not used by the 1x1 branches below.
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # One 1x1 convolution (in_channels -> x) per output channel.
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_channels, x, 1, padding=0, bias=bias) for _ in range(out_channels)]
        )

    def forward(self, input):
        """
        Compute the output tensor given an input tensor.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            output (torch.Tensor): Output tensor of shape (batch_size, out_channels, height, width, x).
        """
        output = torch.empty(
            (input.shape[0], self.out_channels, input.shape[2], input.shape[3], self.x),
            dtype=input.dtype,
            device=input.device,
        )
        for i in range(self.out_channels):
            # Each branch returns (batch_size, x, height, width); move the x axis last
            # so it matches the (batch_size, height, width, x) slot in the output.
            output[:, i] = self.branches[i](input).permute(0, 2, 3, 1)
        return output
The problem with my module is that it’s very inefficient, because out_channels can be very large and the forward method runs one convolution per branch inside a Python for loop. I would like to avoid that loop.
I’ve tried the recommendation to use torch.stack to combine the outputs of the branches into the final output, but it doesn’t improve the speed much.
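For reference, a torch.stack variant of the forward pass might look roughly like the sketch below. The helper name forward_with_stack is made up for illustration, and the output layout (batch_size, out_channels, height, width, x) is assumed from the docstring above. It still makes one Python-level convolution call per branch, which would be consistent with the small speedup.

```python
import torch
import torch.nn as nn


def forward_with_stack(branches, input):
    """Hypothetical helper: run every branch, then stack the results.

    branches: iterable of nn.Conv2d modules, each mapping in_channels -> x.
    input:    tensor of shape (batch_size, in_channels, height, width).
    """
    # Each branch returns (B, x, H, W); permute moves x last -> (B, H, W, x).
    outs = [branch(input).permute(0, 2, 3, 1) for branch in branches]
    # Stacking along dim=1 gives (B, N, H, W, x), where N = number of branches.
    return torch.stack(outs, dim=1)
```

This removes the preallocated buffer and the in-place writes, but not the loop over branches itself.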
Is there any way to remedy this? I’m thinking of a grouped convolution or Conv3d.
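To make the idea concrete: since every branch receives the same input, N separate 1x1 convolutions (in_channels -> x) compute the same thing as one 1x1 convolution (in_channels -> N * x) whose output channels are then split into N blocks of x. Below is a minimal sketch of that fusion, not a tested drop-in replacement. The class name FusedBranches is hypothetical, and the output layout (batch_size, out_channels, height, width, x) is assumed from the docstring above.

```python
import torch
import torch.nn as nn


class FusedBranches(nn.Module):
    """Hypothetical fused version: one 1x1 conv instead of N branches."""

    def __init__(self, in_channels, out_channels, x, bias=True):
        super().__init__()
        self.out_channels = out_channels
        self.x = x
        # A single 1x1 convolution that produces all N * x channels at once.
        # Output channels [i * x, (i + 1) * x) play the role of branch i.
        self.conv = nn.Conv2d(in_channels, out_channels * x, 1, bias=bias)

    def forward(self, input):
        b, _, h, w = input.shape
        y = self.conv(input)                             # (B, N * x, H, W)
        y = y.view(b, self.out_channels, self.x, h, w)   # (B, N, x, H, W)
        # Move x to the last axis; the result is a non-contiguous view.
        return y.permute(0, 1, 3, 4, 2)                  # (B, N, H, W, x)
```

With this layout no groups argument is needed, because the branches share their input; grouping would only matter if each branch consumed a different slice of the input channels.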
Any suggestion is appreciated!
Thanks in advance!