I found the expression torch.einsum("btgi,gih->btgh", x, self.weight) in https://github.com/Rikorose/DeepFilterNet/blob/7e4ef182511eabe9c3325292afce0feb5e612e4d/DeepFilterNet/df/modules.py#L774C9-L774C59, which appears to implement a grouped linear layer (one weight matrix per group g).
Is it possible to replace it with a broadcasted batched matmul call, without a manual loop over the groups?
Something like torch.matmul(x, self.weight[None, None, :, :, :])?
Would it be faster than the einsum?
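For reference, here is a minimal sketch of one broadcasted-matmul formulation that matches the einsum. The shapes (B, T, G, I) for x and (G, I, H) for the weight are assumptions inferred from the einsum subscripts, not taken from the repo. Note that the [None, None] indexing alone likely won't line up: torch.matmul treats the last two dims as the matrix, so it would take x's (G, I) as the matrix and try to broadcast T against G in the batch dims. Inserting a singleton dim into x instead makes the broadcast work:

```python
import torch

# Assumed shapes (inferred from the einsum subscripts "btgi,gih->btgh"):
# x: (B, T, G, I), weight: (G, I, H), output: (B, T, G, H)
B, T, G, I, H = 2, 5, 3, 4, 6
x = torch.randn(B, T, G, I)
weight = torch.randn(G, I, H)

ref = torch.einsum("btgi,gih->btgh", x, weight)

# matmul multiplies the trailing two dims, so give x a singleton row dim:
# (B, T, G, 1, I) @ (G, I, H) broadcasts the (G,) batch dim of the weight
# against (B, T, G) and yields (B, T, G, 1, H).
out = torch.matmul(x.unsqueeze(-2), weight).squeeze(-2)

print(torch.allclose(ref, out, atol=1e-5))
```

As for speed: in recent PyTorch versions torch.einsum often lowers to the same bmm/matmul kernels, so any difference is best checked by benchmarking on your actual shapes and device rather than assumed.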