How can I avoid for loop in forward method of a custom module?

Hi,
I made a custom module containing N branches (N = out_channels), each an identical Conv2d(in_channels, x, 1, padding=0, bias=bias). During forward, the input is passed through each branch, and the results are stacked into the final output, which has 5 dimensions: (batch_size, N, h, w, x).
The code is as follows:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    """
    A convolutional module.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels (the number of branches, N).
        x (int): Number of channels produced by each branch.
        kernel_size (int, optional): Size of the convolution kernel. Defaults to 3.
        stride (int, optional): Stride. Defaults to 1.
        padding (int, optional): Amount of padding to add to the input tensor. Defaults to 1.
        bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.

    Attributes:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels (the number of branches, N).
        x (int): Number of channels produced by each branch.
        kernel_size (int): Kernel size of the weight tensor.
        stride (int): Stride.
        padding (int): Amount of padding to add to the input tensor.

    """

    def __init__(self, in_channels, out_channels, x, kernel_size=3, stride=1, padding=1, bias=True):
        super(MyModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x = x
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # One 1x1 conv per output channel (kernel_size/stride/padding are stored
        # above, but the branches themselves are fixed 1x1 convs).
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_channels, x, 1, padding=0, bias=bias) for _ in range(out_channels)]
        )

    def forward(self, input):
        """
        Compute the output tensor given an input tensor.

        Args:
            input (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            output (torch.Tensor): Output tensor of shape (batch_size, out_channels, height, width, x).

        """
        output = torch.empty(
            (input.shape[0], self.out_channels, input.shape[2], input.shape[3], self.x),
            dtype=input.dtype, device=input.device,
        )
        for i in range(self.out_channels):
            # Each branch returns (batch, x, h, w); move x to the last dim so it
            # matches the (batch, h, w, x) slice of the output.
            output[:, i] = self.branches[i](input).permute(0, 2, 3, 1)

        return output
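For concreteness, here is a toy call (the sizes are made up, just to illustrate the shapes):

module = MyModule(in_channels=64, out_channels=512, x=16, bias=False)
inp = torch.randn(8, 64, 32, 32)
out = module(inp)    # loops over all 512 branches
print(out.shape)     # torch.Size([8, 512, 32, 32, 16])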

The problem with my module is that it is very inefficient, because out_channels can be very large and the for loop in the forward method launches one conv per branch. I've tried the common recommendation of using torch.stack to stack the branch outputs into the final output (roughly the sketch below), but it doesn't improve the speed much.
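The stacked variant I tried looks roughly like this; it replaces the preallocation, but the Python loop over branches is still there:

    def forward(self, input):
        # torch.stack only changes how the results are collected; there is
        # still one separate conv call per branch.
        outs = [branch(input).permute(0, 2, 3, 1) for branch in self.branches]
        return torch.stack(outs, dim=1)  # (batch, out_channels, h, w, x)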

Is there any way to remedy this? I'm thinking of grouped convolution or Conv3d.
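For example, I imagine the grouped-convolution version would look something like this rough, untested sketch (the input has to be repeated once per branch, which costs memory):

# groups=out_channels gives each branch its own in_channels -> x 1x1 filter bank.
conv = nn.Conv2d(in_channels * out_channels, x * out_channels, 1,
                 groups=out_channels, bias=False)
rep = input.repeat(1, out_channels, 1, 1)  # copy the input once per branch
out = conv(rep)                            # (batch, out_channels * x, h, w)
out = out.view(input.shape[0], out_channels, x, *out.shape[2:]).permute(0, 1, 3, 4, 2)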

Any suggestion is appreciated!

Thanks in advance!

Wouldn’t increasing the number of output channels work, since you are already applying different conv filters to the full input activation?
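Since every branch is a 1x1 conv applied to the same input, all of them can be fused into a single Conv2d with out_channels * x filters and the result reshaped afterwards. A minimal sketch (the class name and reshaping are mine, assuming the fixed 1x1 branches from your code):

import torch
import torch.nn as nn

class FusedModule(nn.Module):
    def __init__(self, in_channels, out_channels, x, bias=True):
        super().__init__()
        self.out_channels = out_channels
        self.x = x
        # One conv holding all out_channels branches at once.
        self.conv = nn.Conv2d(in_channels, out_channels * x, 1, bias=bias)

    def forward(self, input):
        out = self.conv(input)             # (batch, out_channels * x, h, w)
        b, _, h, w = out.shape
        out = out.view(b, self.out_channels, self.x, h, w)
        return out.permute(0, 1, 3, 4, 2)  # (batch, out_channels, h, w, x)

This launches one conv kernel instead of out_channels of them, and its parameters are exactly the per-branch weights concatenated along the output-channel dimension.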

Hi,
Actually, this is similar to the inception module, where each branch plays its own role. But an inception module normally contains a small, fixed number of branches (e.g., three, with 1x1, 3x3, and 5x5 kernels), while mine contains a very large number of branches (maybe 512).
Is there really no way around this?
Thanks!