Pass input data to independent modules of nn.ModuleList concurrently


I have a simple module list, defined as follows:

self.Linear = nn.ModuleList()
for i in range(self.channels):
    self.Linear.append(nn.Linear(self.seq_len, self.pred_len))

In the forward pass, I pass the input data to this module list sequentially, which is not efficient. It is worth mentioning that these modules are independent of each other, and ideally we should be able to pass the data to them concurrently. Is there a way to do that?

Here is what I have in the forward method:

for i in range(self.channels):
    output[:,:,i] = self.Linear[i](x[:,:,i])

@ptrblck can you help me with this?

You could check vmap’s ensemble approach described in the torch.func model ensembling tutorial.
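For reference, here is a minimal sketch of how that ensemble approach might look for the per-channel linear layers from the question. The sizes (`batch`, `seq_len`, `pred_len`, `channels`) and the helper name `call_one` are made up for illustration, and it assumes a PyTorch version that ships `torch.func` (2.0 or later). It stacks the weights of all layers, then uses `vmap` to map over the channel dimension of `x` in a single batched call, and compares the result against the sequential loop.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, pred_len, channels = 4, 16, 8, 3

linears = nn.ModuleList(nn.Linear(seq_len, pred_len) for _ in range(channels))
x = torch.randn(batch, seq_len, channels)

# Reference result: the sequential loop from the question.
expected = torch.empty(batch, pred_len, channels)
with torch.no_grad():
    for i in range(channels):
        expected[:, :, i] = linears[i](x[:, :, i])

# Stack the parameters of all layers along a new leading dimension:
# weight -> [channels, pred_len, seq_len], bias -> [channels, pred_len].
params, buffers = stack_module_state(list(linears))

# A "stateless" copy of one layer; only its structure is used.
base = copy.deepcopy(linears[0]).to("meta")


def call_one(p, b, xi):
    # Run the base layer with one slice of the stacked parameters.
    return functional_call(base, (p, b), (xi,))


# Map over dim 0 of the stacked parameters and dim 2 (channels) of x,
# and place the mapped dimension at dim 2 of the output.
out = vmap(call_one, in_dims=(0, 0, 2), out_dims=2)(params, buffers, x)

print(out.shape)
print(torch.allclose(out, expected, atol=1e-5))
```

One caveat: `stack_module_state` returns new leaf tensors that are not tied to the original `nn.ModuleList` parameters, so for training you would pass the stacked `params` to the optimizer rather than `linears.parameters()`.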
