Possible to replace einsum expr "btgi,gih->btgh" with a matmul call?

I found the expression torch.einsum("btgi,gih->btgh", x, self.weight) in https://github.com/Rikorose/DeepFilterNet/blob/7e4ef182511eabe9c3325292afce0feb5e612e4d/DeepFilterNet/df/modules.py#L774C9-L774C59, where it implements a GroupedLinear layer.

Is it possible to somehow replace it with a broadcasted, batched matmul call, without a manual loop? Something like:

torch.matmul(x, self.weight[None, None, :, :, :])?

Would it be faster than the einsum?

Hmm, it seems torch.matmul broadcasting somehow does not work here, although the docs say it should…

import torch

x = torch.rand(10, 20, 30, 40)  # (b, t, g, i)
weight = torch.rand(30, 40, 50)  # (g, i, h)

print(torch.einsum("btgi,gih->btgh", x, weight).shape)
# torch.Size([10, 20, 30, 50])

print(torch.matmul(x, weight[None, None, ...]).shape)
# RuntimeError: The size of tensor a (20) must match the size of tensor b (30) at non-singleton dimension 2

@Lezcano answered in torch.matmul copies instead of broadcasting, while F.linear doesn't · Issue #76702 · pytorch/pytorch · GitHub! Basically, this is a batched vector-matrix product, so the batch dims are computed differently than I assumed: matmul only treats the last two dims as matrix dims, so it sees x as a (10, 20) batch of (30, 40) matrices and weight[None, None, ...] as a (1, 1, 30) batch of (40, 50) matrices, and the batch shapes (10, 20) and (1, 1, 30) don't broadcast.
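For reference, the same contraction can be expressed with matmul by making the vector-matrix structure explicit: unsqueeze x into a batch of (1, i) row vectors so that g lines up as a batch dim on both operands. A minimal sketch (shapes as in the example above; ref and out are just local names, not from the DeepFilterNet code):

import torch

x = torch.rand(10, 20, 30, 40)  # (b, t, g, i)
weight = torch.rand(30, 40, 50)  # (g, i, h)

ref = torch.einsum("btgi,gih->btgh", x, weight)

# (b, t, g, 1, i) @ (g, i, h): batch dims (b, t, g) and (g,) broadcast to
# (b, t, g); matrix dims (1, i) @ (i, h) give (1, h), squeezed back out.
out = torch.matmul(x.unsqueeze(-2), weight).squeeze(-2)

print(out.shape)  # torch.Size([10, 20, 30, 50])
print(torch.allclose(ref, out))  # True

As for speed, einsum typically lowers a contraction like this to the same batched-matmul kernels after a permute/reshape, and per the issue above matmul may materialize the broadcast of weight as a copy, so I wouldn't assume either version is faster without measuring, e.g. with torch.utils.benchmark:

from torch.utils.benchmark import Timer

for stmt in [
    'torch.einsum("btgi,gih->btgh", x, w)',
    "torch.matmul(x.unsqueeze(-2), w).squeeze(-2)",
]:
    print(Timer(stmt, globals={"torch": torch, "x": x, "w": weight}).timeit(1000))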