How to speed up small parallel nn.Linear blocks (or small parallel matrix multiplications or group convolutions)?

These three are essentially the same operation, and unfortunately they’re all very slow in PyTorch. Group convolutions with stride=1 and kernel_size=1 were my main hope for a fast implementation, but they’re especially slow when the group size (channels / groups) is around 100.

I tried @torch.jit.script on a loop-based implementation, but it doesn’t help. Any pointers on what to try next? I’ve been reading about torchdynamo, NVFuser, TVM, Halide, Triton and JAX integration: would you suggest trying any of them? Something else? Or should I just drop down to C++/CUDA and try to write a PyTorch extension?

Hi @ipostr08,

Have you looked at functorch? You might be able to use vmap to speed up your use case.
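For anyone finding this later, here is a minimal sketch of what the vmap suggestion could look like, assuming the per-group weights and biases are stacked into single tensors (the shapes and names are illustrative, not code from this thread):

import torch
import torch.nn.functional as F
from functorch import vmap  # torch.vmap in recent PyTorch releases

groups, in_f, out_f = 100, 8, 8
batch, length = 16, 256

# Illustrative: one weight/bias per group, stacked along dim 0.
weights = torch.randn(groups, out_f, in_f)
biases = torch.randn(groups, out_f)

# x: (batch, groups, length, in_f) -- each group sees its own slice of the input.
x = torch.randn(batch, groups, length, in_f)

# Map F.linear over the group dimension: dim 1 of x, dim 0 of weights/biases.
grouped_linear = vmap(F.linear, in_dims=(1, 0, 0), out_dims=1)
y = grouped_linear(x, weights, biases)  # (batch, groups, length, out_f)

Under vmap this effectively becomes a single batched matmul rather than a Python loop over the groups.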

Thanks for the answer @AlphaBetaGamma96, I haven’t tried it yet.

By the way, can I put a bounty on this question? $1000 by PayPal if somebody can give me a ready solution, or $200 if somebody points me to something that works with my setup and that I can implement without too much hassle. I think I’m unnecessarily spending money on too many cloud GPUs and wasting electricity to boot.

Not interested in your money, if that’s what you’re offering, but I just posted something which sounds like it might be useful for you here: Parallel execution of modules in nn.ModuleList - PyTorch Forums

I’m doing parallel position-wise linear layers, where each parallel channel learns its own fully connected layer. In my case each ‘channel’ gets its own input, but you can do what I described in that topic if you want to apply the same input to a single ‘channel.’ I’m still testing it; stability has been a bit of a problem, but I think I’ve gotten it stable with a fair amount of normalization. (Edit 2: after adding layer norm in more places before or after the multichannel linear layers, I was able to stabilize my multichannel transformer.)

Edit: I realized I didn’t post the edit I was thinking of to that thread, but here’s a quick toy example:

import torch

m = MultichannelLinear(4, 8, 8)  # MultichannelLinear is defined further down in this thread
b = torch.ones((1, 8, 16))  # B, H, W
b = b.unsqueeze(1).expand((1, 4, 8, 16))  # B, C, H, W
c = torch.mean(m(b), dim=1)  # average over the channel dim, back to B, H, W

You could of course use a depthwise convolution to compress to a single channel and then squeeze that as well, depending on your use case.

Also, share a minimal reproducible example so people can debug your problem!

Something simple like this:

import torch
import torch.nn as nn

class SlowGroups(nn.Module):
    def __init__(self, channels: int, groups: int):
        super().__init__()
        assert channels % groups == 0
        self.groups = groups
        # One small Linear per group, each applied to its own slice of the channels.
        self.projs = nn.ModuleList(
            nn.Linear(channels // groups, channels // groups) for _ in range(groups)
        )

    def forward(self, x):
        # x: (batch, channels, length) -> (batch, length, channels), then split per group
        parts = torch.chunk(x.transpose(-1, -2), self.groups, dim=-1)
        out = torch.cat([self.projs[n](parts[n]) for n in range(self.groups)], dim=-1)
        return out.transpose(-1, -2)

I’ve benchmarked it against Conv1d with kernel_size=1, groups=100, padding=0: Conv1d is somewhat faster (but still slow) for very small group sizes, and actually slower in training for group sizes near 100!
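For reference, the SlowGroups module above should be mathematically equivalent to a single grouped pointwise Conv1d, which is what that comparison refers to. A quick sketch of the correspondence (sizes are illustrative, not the benchmark actually run):

import torch
import torch.nn as nn

channels, groups, length, batch = 800, 100, 64, 4
c = channels // groups
x = torch.randn(batch, channels, length)

slow = SlowGroups(channels, groups)
conv = nn.Conv1d(channels, channels, kernel_size=1, groups=groups, padding=0)

# Copy each group's Linear weight/bias into the grouped conv and compare outputs.
with torch.no_grad():
    for n, lin in enumerate(slow.projs):
        conv.weight[n*c:(n+1)*c] = lin.weight.unsqueeze(-1)  # (c, c, 1) kernels
        conv.bias[n*c:(n+1)*c] = lin.bias

print(torch.allclose(slow(x), conv(x), atol=1e-5))  # expected: True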

The performance increase doesn’t seem huge when the linear layers have equal input and output features, but profiling my solution on my GPU (a 3080 Ti with mixed precision) shows it is a bit faster than Conv1d. In my own use case, where I expand the features for each channel/group before compressing them again in a residual block, the batched matrix multiplication with a single combined weight tensor halves training time versus Conv1d; swapping my module back out for Conv1d doubles it again. The instability I mentioned was due to missing normalization in a few places, caused by accidentally feeding in the wrong variable. It takes a bit of reshaping, but it seems to be far more performant depending on the use case.

For a simple toy example with equal input and output features, as in your example above, grouped Conv1d takes 1924 ms while the batched matrix multiplication comes in at around 1263 ms.
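The exact script behind those numbers isn’t shown; here is a sketch of how such a comparison could be set up with torch.utils.benchmark (the sizes are assumptions, and a CUDA GPU is assumed):

import torch
import torch.nn as nn
from torch.utils import benchmark

batch, channels, groups, length = 16, 800, 100, 256
c = channels // groups
device = 'cuda'

x = torch.randn(batch, channels, length, device=device)
conv = nn.Conv1d(channels, channels, kernel_size=1, groups=groups, padding=0).to(device)

# Batched-matmul version: all group weights in one (groups, c, c) tensor.
w = torch.randn(groups, c, c, device=device)

def bmm_version(x):
    xg = x.view(batch, groups, c, length)  # split the channel dim per group
    return torch.matmul(w, xg).reshape(batch, channels, length)

print(benchmark.Timer(stmt='conv(x)', globals={'conv': conv, 'x': x}).timeit(100))
print(benchmark.Timer(stmt='bmm_version(x)', globals={'bmm_version': bmm_version, 'x': x}).timeit(100))

torch.utils.benchmark.Timer handles CUDA synchronization, so the reported times include the full kernel execution.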

My module:

import math

import torch
import torch.nn as nn

class MultichannelLinear(nn.Module):
    def __init__(self, channels, in_features, out_features):
        super(MultichannelLinear, self).__init__()

        # One (out_features, in_features) weight per channel, stored as a single 3D
        # tensor so the forward pass is one batched matmul instead of a Python loop.
        self.weight_pw = nn.Parameter(torch.empty(channels, out_features, in_features))
        nn.init.uniform_(self.weight_pw, a=-1/math.sqrt(in_features*channels), b=1/math.sqrt(in_features*channels))

    def forward(self, x):
        # x: (B, C, H=in_features, W); each channel applies its own weight along H
        x = torch.matmul(x.transpose(2, 3), self.weight_pw.transpose(1, 2)).transpose(2, 3)
        return x

Update: using functorch’s vmap helps, though the difference is not spectacular. It can even be a little faster than grouped Conv1d in training for some sizes, but only if a different tensor memory layout than Conv1d’s default is used. Conv1d is still faster in inference. Tested on a 3090.
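To illustrate the layout point (the exact permutation here is an assumption, not the code actually used): Conv1d wants channels-first (N, C, L) input, while the vmap-ed linear works on (..., length, features) slices, so it can help to keep the tensor group-major and feature-last so no transpose or copy is needed on every call.

import torch
import torch.nn.functional as F
from functorch import vmap

batch, groups, c, length = 16, 100, 8, 256
weights = torch.randn(groups, c, c)

# Conv1d's default layout: (N, C, L) = (batch, groups * c, length)
x_conv = torch.randn(batch, groups * c, length)

# Layout for the vmap-ed linear: (groups, batch, length, c), made contiguous
# once up front instead of transposing inside the hot path.
x_vmap = (x_conv.view(batch, groups, c, length)
                .permute(1, 0, 3, 2)
                .contiguous())

per_group_linear = vmap(F.linear)      # maps over dim 0 of both arguments
y = per_group_linear(x_vmap, weights)  # (groups, batch, length, c)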