Parallel execution of modules in nn.ModuleList

Hi, I have a ModuleList containing many modules.
I use the modules in the ModuleList as the components of a single layer, so in the forward function I iterate through the list and then concatenate the results of the individual modules.
(Basically, I simulate the old Torch nn.Parallel.)

Is there a way to do this efficiently (in parallel)? It is very costly with a big ModuleList.
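The pattern described above can be sketched like this, assuming every module receives the same full input (`ParallelConcat` is a hypothetical name, not a PyTorch class). The list comprehension runs the branches one after another in Python, which is where the cost grows with the length of the ModuleList.

```python
import torch
from torch import nn


class ParallelConcat(nn.Module):
    """Hypothetical container: every branch sees the same input and the
    branch outputs are concatenated along `dim`."""

    def __init__(self, branches, dim=-1):
        super().__init__()
        self.branches = nn.ModuleList(branches)
        self.dim = dim

    def forward(self, x):
        # Branches are executed sequentially; nothing here runs in parallel.
        return torch.cat([branch(x) for branch in self.branches], dim=self.dim)


layer = ParallelConcat([nn.Linear(8, 4) for _ in range(3)])
out = layer(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 12])
```

If your modules instead take different slices of the input, split `x` (for example with `torch.chunk`) before calling each branch.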



Did you find the answer to this question? I am having the same problem - would like to execute a ModuleList in parallel.

Hi Olivier,

No, I haven’t yet.

I asked a similar question on StackOverflow and got an interesting answer:

I haven’t had time to try it out yet, but it looks promising.

Actually, I also asked the same question here recently but got no answers: Running multiple Modules in parallel

I skimmed it, but it seems a bit tricky and is constrained to using conv layers.

I tried the answer from Stack Overflow; it seems to be faster on a GPU and slower without one:

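```python
import torch
from torch import nn


class MultiHeadParallel(nn.Module):
    """All heads packed into grouped Conv1d layers (kernel_size=1 makes each group a Linear)."""

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()
        self.networks = nn.Sequential(
            nn.Conv1d(in_channels=input_dim * nb_heads, out_channels=hidden_size * nb_heads, kernel_size=1, groups=nb_heads),
            nn.Conv1d(in_channels=hidden_size * nb_heads, out_channels=output_dim * nb_heads, kernel_size=1, groups=nb_heads),
        )
        self.nb_heads = nb_heads

    def forward(self, x):
        # Give every head its own copy of x: (batch, input_dim * nb_heads, 1)
        x = x.repeat(1, self.nb_heads).unsqueeze(-1)
        flat = self.networks(x)  # (batch, output_dim * nb_heads, 1)
        batch_size = x.shape[0]
        # Group h owns the contiguous channels [h * output_dim, (h + 1) * output_dim),
        # so split by head first, then move heads last to match MultiHeadNaive.
        return flat.view(batch_size, self.nb_heads, -1).transpose(1, 2)


class MultiHeadNaive(nn.Module):
    """One independent two-layer MLP per head, run in a Python loop."""

    def __init__(self, input_dim, output_dim, hidden_size=32, nb_heads=1):
        super().__init__()
        self.networks = nn.ModuleList()
        for _ in range(nb_heads):
            network = nn.Sequential(
                nn.Linear(input_dim, hidden_size),
                nn.Linear(hidden_size, output_dim),
            )
            self.networks.append(network)

    def forward(self, x):
        outputs = [net(x) for net in self.networks]
        return torch.stack(outputs, dim=-1)  # (batch, output_dim, nb_heads)


IN_DIM = 256
OUT_DIM = 256
NB_HEADS = 1000
BATCH_SIZE = 32  # not shown in the original post; placeholder value

net_parallel = MultiHeadParallel(IN_DIM, OUT_DIM, nb_heads=NB_HEADS)
net_naive = MultiHeadNaive(IN_DIM, OUT_DIM, nb_heads=NB_HEADS)
x = torch.randn(BATCH_SIZE, IN_DIM)

print("***Without GPU***")
%time y_naive = net_naive(x)
print("\nWith Conv1D")
%time y_parallel = net_parallel(x)

print("\n***With GPU***")
# The modules must be moved to the GPU as well, not only the input.
net_naive = net_naive.cuda()
net_parallel = net_parallel.cuda()
x = x.cuda()
%time y_naive = net_naive(x); torch.cuda.synchronize()
print("\nWith Conv1D")
%time y_parallel = net_parallel(x); torch.cuda.synchronize()
```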


***Without GPU***
CPU times: user 176 ms, sys: 0 ns, total: 176 ms
Wall time: 176 ms

With Conv1D
CPU times: user 629 ms, sys: 17 µs, total: 629 ms
Wall time: 633 ms

***With GPU***
CPU times: user 11.2 ms, sys: 3.05 ms, total: 14.2 ms
Wall time: 14.1 ms

With Conv1D
CPU times: user 4.07 ms, sys: 971 µs, total: 5.04 ms
Wall time: 4.86 ms

Edit: added the torch.cuda.synchronize() calls and new timings.
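To check that the grouped-convolution trick computes the same thing as separate linear layers, here is a small sketch with toy sizes (all names are illustrative). With `kernel_size=1` and `groups=heads`, group `h` reads input channels `[h*in_dim, (h+1)*in_dim)` and writes the contiguous output channels `[h*out_dim, (h+1)*out_dim)`, so the output should be reshaped as `(batch, heads, out_dim)` before any transpose.

```python
import torch
from torch import nn

torch.manual_seed(0)
in_dim, out_dim, heads, batch = 5, 3, 4, 2

# One grouped Conv1d holds `heads` independent (out_dim x in_dim) weight matrices.
conv = nn.Conv1d(in_dim * heads, out_dim * heads, kernel_size=1, groups=heads)

x = torch.randn(batch, in_dim)
x_rep = x.repeat(1, heads).unsqueeze(-1)     # (batch, in_dim * heads, 1): one copy of x per head
y_conv = conv(x_rep).squeeze(-1)             # (batch, out_dim * heads)
y_conv = y_conv.view(batch, heads, out_dim)  # split the contiguous per-group channels

# Rebuild each head as an ordinary linear layer from its slice of the conv weights.
w = conv.weight.squeeze(-1)                  # (out_dim * heads, in_dim)
y_loop = torch.stack(
    [
        nn.functional.linear(x, w[h * out_dim:(h + 1) * out_dim], conv.bias[h * out_dim:(h + 1) * out_dim])
        for h in range(heads)
    ],
    dim=1,
)                                            # (batch, heads, out_dim)

print(torch.allclose(y_conv, y_loop, atol=1e-5))  # True
```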


Since CUDA operations are asynchronous, could you add torch.cuda.synchronize() to your profiling code and rerun it, please?
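A minimal timing helper along these lines, assuming plain wall-clock timing is good enough (`time_forward` is a hypothetical name). Kernel launches return before the GPU has finished, so without `torch.cuda.synchronize()` the timer mostly measures launch overhead; the warm-up iterations keep one-time lazy initialization out of the measurement. On a CPU-only machine the synchronize calls are simply skipped.

```python
import time

import torch
from torch import nn


def time_forward(module, x, warmup=3, iters=10):
    """Average seconds per forward pass, synchronizing CUDA around the timed loop."""
    use_cuda = x.is_cuda
    with torch.no_grad():
        for _ in range(warmup):          # warm-up: lazy initialization, allocator, etc.
            module(x)
        if use_cuda:
            torch.cuda.synchronize()     # make sure warm-up work has finished
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if use_cuda:
            torch.cuda.synchronize()     # wait until all queued kernels are done
        return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Sequential(nn.Linear(256, 32), nn.Linear(32, 256)).to(device)
x = torch.randn(64, 256, device=device)
print(f"{device}: {time_forward(net, x) * 1e3:.3f} ms per forward")
```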

Done. The results are basically the same. (I’ve edited my original post).


Did you test the backpropagation time for these two networks?
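One way to measure forward plus backward time is `torch.utils.benchmark.Timer`, which runs warm-up iterations and synchronizes CUDA itself. A sketch with small, arbitrary sizes for the naive per-head loop; the same `step` function can be pointed at the grouped-conv version for comparison.

```python
import torch
from torch import nn
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
heads, in_dim, out_dim, batch = 8, 64, 64, 32

naive = nn.ModuleList(
    nn.Sequential(nn.Linear(in_dim, 32), nn.Linear(32, out_dim)) for _ in range(heads)
).to(device)
x = torch.randn(batch, in_dim, device=device)


def step():
    # One full training-style pass: clear grads, forward through every head, backward.
    naive.zero_grad(set_to_none=True)
    y = torch.stack([net(x) for net in naive], dim=-1)
    y.sum().backward()


timer = benchmark.Timer(stmt="step()", globals={"step": step})
m = timer.timeit(20)
print(f"forward+backward: {m.mean * 1e3:.3f} ms")
```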

Did you find a good way to deal with the problem? Using grouped convolutions is restrictive.
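An alternative that is not limited to conv layers, swapped in here rather than taken from this thread, is the model-ensembling approach from `torch.func` (PyTorch 2.0 or later): stack the parameters of identically structured heads with `stack_module_state` and map a single functional call over them with `vmap`. A minimal sketch, assuming all heads share one architecture; it reproduces the loop's output, but whether it is faster depends on sizes and device, so it is worth benchmarking.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)
heads, in_dim, hidden, out_dim, batch = 4, 16, 32, 8, 5

heads_list = [
    nn.Sequential(nn.Linear(in_dim, hidden), nn.Linear(hidden, out_dim))
    for _ in range(heads)
]

# Each parameter stacked across heads, e.g. params["0.weight"] has shape (heads, hidden, in_dim).
params, buffers = stack_module_state(heads_list)

# Stateless template; its own (meta) weights are never used.
base = copy.deepcopy(heads_list[0]).to("meta")


def run_head(p, b, x):
    return functional_call(base, (p, b), (x,))


x = torch.randn(batch, in_dim)
# Map over the leading head dimension of params/buffers; x is shared by every head.
y = vmap(run_head, in_dims=(0, 0, None))(params, buffers, x)  # (heads, batch, out_dim)
y = y.permute(1, 2, 0)                                         # (batch, out_dim, heads)

y_loop = torch.stack([h(x) for h in heads_list], dim=-1)
print(torch.allclose(y, y_loop, atol=1e-5))  # True
```

`stack_module_state` returns new leaf tensors, so gradients accumulate in `params` rather than in `heads_list`; to train this way you would pass `params.values()` to the optimizer.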

I’ve recently been running up against this problem with a weird transformer I’m working on. I ended up coding a custom linear layer with a channel dimension, though I’m seeing some instability with it (I think due to initialization, but I’m not sure; that’s part of why I’m posting this).

My solution was as follows:

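```python
import math

import torch
from torch import nn


class MultichannelLinear(nn.Module):
    def __init__(self, channels, in_features, out_features):
        super().__init__()
        # One (out_features x in_features) weight matrix per channel.
        self.weight_pw = nn.Parameter(torch.empty(channels, out_features, in_features))
        nn.init.uniform_(self.weight_pw, a=-1/math.sqrt(in_features*channels), b=1/math.sqrt(in_features*channels))

    # Override forward (not __call__) so nn.Module hooks and related machinery keep working.
    def forward(self, x):
        # Layout inferred from the transposes: (batch, channels, in_features, length)
        # -> (batch, channels, out_features, length)
        x = torch.matmul(x.transpose(2, 3), self.weight_pw.transpose(1, 2)).transpose(2, 3)
        return x
```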
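As a sanity check on what the matmul in this layer computes, here is a sketch with toy sizes, assuming the input layout `(batch, channels, in_features, length)` implied by the transposes. The same contraction can be written as an `einsum`, and both should match applying an ordinary linear layer separately to each channel.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, channels, in_f, out_f, seq = 2, 3, 4, 5, 6

x = torch.randn(batch, channels, in_f, seq)   # assumed layout: (B, C, in_features, L)
weight = torch.randn(channels, out_f, in_f)   # same layout as weight_pw

# The expression used in MultichannelLinear.
y = torch.matmul(x.transpose(2, 3), weight.transpose(1, 2)).transpose(2, 3)

# Equivalent einsum: contract in_features separately for every channel.
y_einsum = torch.einsum("coi,bcil->bcol", weight, x)

# Reference: one ordinary linear layer per channel.
y_loop = torch.stack(
    [F.linear(x[:, c].transpose(1, 2), weight[c]).transpose(1, 2) for c in range(channels)],
    dim=1,
)

print(y.shape)  # torch.Size([2, 3, 5, 6])
print(torch.allclose(y, y_einsum, atol=1e-5), torch.allclose(y, y_loop, atol=1e-5))  # True True
```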

It does extremely well for a little while, but eventually this multichannel linear layer seems to produce NaNs in its output (checking each module’s output led me to this matmul call as the culprit). I had previously been initializing via nn.init.uniform_(self.weight_pw, a=-1/math.sqrt(in_features), b=1/math.sqrt(in_features)) and am currently testing the version above, but I’m not totally sure which is right.
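For reference on the initialization question: `nn.Linear` draws its weight (and bias) from U(-1/sqrt(in_features), 1/sqrt(in_features)), so the earlier bound of `1/math.sqrt(in_features)` is the per-channel equivalent, while dividing by `sqrt(channels)` as well shrinks every channel's weights. Note also that calling `nn.init.kaiming_uniform_` directly on a 3-D `(channels, out_features, in_features)` tensor would compute fan-in as `out_features * in_features`, which is not what you want here. A sketch (`init_like_linear_` and the per-channel bias are hypothetical):

```python
import math

import torch
from torch import nn


def init_like_linear_(weight, bias=None):
    """Initialize a (channels, out_features, in_features) weight the way
    nn.Linear initializes each (out_features, in_features) slice."""
    in_features = weight.shape[-1]
    bound = 1.0 / math.sqrt(in_features)  # per-channel fan-in, not in_features * channels
    with torch.no_grad():
        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)
    return weight


w = nn.Parameter(torch.empty(8, 32, 64))
b = nn.Parameter(torch.empty(8, 32, 1))  # hypothetical per-channel bias, broadcast over length
init_like_linear_(w, b)

# A real nn.Linear(64, 32) uses the same bound: 1/sqrt(64) = 0.125.
ref = nn.Linear(64, 32)
print(w.abs().max().item() <= 0.125, ref.weight.abs().max().item() <= 0.125)  # True True
```

This only matches the default scale of `nn.Linear`; as the later edits in this post note, the NaNs turned out to be addressed by adding normalization rather than by initialization alone.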

Edit: I actually think the instability stems from elsewhere in my architecture, so who knows, maybe this will work for someone, or someone can point out a flaw in it for me lol.

Edit 2: Last edit lol. It seems I just needed more normalization, including normalizing the query and key tensors as in 2010.04245.pdf. In addition, I added layer norm between the two feedforward layers in the transformer architecture. Without normalization and with no LR warmup, it became unstable within 17 seconds, whereas with the layer norm it trained for 20 minutes or so (I wanted to start a full run, so I cancelled that one).
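The two changes described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the poster's code: `QKNormAttention` and `NormedFeedForward` are hypothetical names, attention is single-head, and `init_scale=10.0` is an arbitrary placeholder rather than the initialization recommended in 2010.04245. Queries and keys are L2-normalized so their dot products are cosine similarities, and a learnable scale replaces the usual 1/sqrt(d). The GELU activation and its position after the LayerNorm are also assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class QKNormAttention(nn.Module):
    """Single-head self-attention with L2-normalized queries/keys and a learnable scale."""

    def __init__(self, dim, init_scale=10.0):  # init_scale: arbitrary placeholder
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.scale = nn.Parameter(torch.tensor(init_scale))

    def forward(self, x):  # x: (batch, seq, dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = F.normalize(q, dim=-1)  # unit-length queries
        k = F.normalize(k, dim=-1)  # unit-length keys
        # Cosine similarities in [-1, 1], so logits are bounded by |scale|.
        scores = self.scale * (q @ k.transpose(-2, -1))
        return self.out(scores.softmax(dim=-1) @ v)


class NormedFeedForward(nn.Module):
    """Feed-forward block with a LayerNorm between the two linear layers."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)


x = torch.randn(2, 10, 64)
y = NormedFeedForward(64, 256)(QKNormAttention(64)(x))
print(y.shape)  # torch.Size([2, 10, 64])
```

Bounding the attention logits and renormalizing the hidden activations both limit how large intermediate values can grow, which fits the observation that the run stayed stable much longer once normalization was added.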