How to speed up small parallel nn.Linear blocks (or small parallel matrix multiplications or group convolutions)?

I'm not interested in your money if that's what you're offering, but I just posted something that sounds like it might be useful for you here: Parallel execution of modules in nn.ModuleList - PyTorch Forums

I’m doing parallel position-wise linear layers where each parallel channel learns its own fully connected layer. In my case each channel gets a different input, but you can do what I described in that topic if you want to apply it to a single ‘channel.’ I’m still testing it; stability has been a bit of a problem, but I think I’ve gotten it stable with a fair amount of normalization. (Edit 2: after adding layer norm in more places before or after the multichannel linear layers, I was able to stabilize my multichannel transformer.)

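Since the implementation isn't shown in this thread, here is a minimal sketch of what a MultichannelLinear module along these lines could look like. The constructor signature (channels, in_features, out_features), the einsum-based forward, and the initialization are assumptions chosen to match the toy example below, not the exact code from the linked topic:

import torch
import torch.nn as nn

class MultichannelLinear(nn.Module):
    # One independent fully connected layer per channel, applied
    # position-wise along the H (feature) dimension of a B,C,H,W tensor.
    def __init__(self, channels, in_features, out_features):
        super().__init__()
        # A separate (out_features, in_features) weight matrix per channel.
        self.weight = nn.Parameter(torch.empty(channels, out_features, in_features))
        bound = 1 / in_features ** 0.5
        nn.init.uniform_(self.weight, -bound, bound)

    def forward(self, x):
        # x: (B, C, H_in, W) -> (B, C, H_out, W); the einsum is a batched
        # matrix multiply that uses a different weight for each channel.
        return torch.einsum('coi,bciw->bcow', self.weight, x)
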
Edit: I realized I didn’t post the edit I was thinking of to that thread, but this is a quick toy example:

m = MultichannelLinear(4, 8, 8) # 4 channels, 8 features in, 8 features out
b = torch.ones((1, 8, 16)) # B,H,W
b = b.unsqueeze(1).expand((1, 4, 8, 16)) # B,C,H,W
c = torch.mean(m(b), dim=1) # average the channels back down to B,H,W

You could of course use a convolution (e.g. a pointwise 1x1 conv) to compress down to a single channel and then squeeze that dimension instead of taking the mean, depending on your use case.
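
As a hedged sketch of that alternative, reusing m and b from the toy example above and assuming 4 channels (the conv layer and its 1x1 kernel size are my choices here, not code from the linked topic):

reduce = nn.Conv2d(4, 1, kernel_size=1) # learn the channel reduction instead of averaging
c = reduce(m(b)).squeeze(1)             # B,1,H,W -> B,H,W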