I found this expression in https://github.com/Rikorose/DeepFilterNet/blob/7e4ef182511eabe9c3325292afce0feb5e612e4d/DeepFilterNet/df/modules.py#L774C9-L774C59, where it implements a grouped linear layer (each group g has its own i-by-h weight matrix):
torch.einsum("btgi,gih->btgh", x, self.weight)
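For reference, the einsum computes out[b, t, g, h] = sum over i of x[b, t, g, i] * weight[g, i, h]. Below is a minimal sketch that checks this against an explicit loop over groups. The sizes are hypothetical, and weight stands in for self.weight.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes: batch b, time t, groups g, per-group input i, per-group output h
b, t, g, i, h = 2, 3, 4, 5, 6
x = torch.randn(b, t, g, i)
weight = torch.randn(g, i, h)  # stands in for self.weight

# The expression from the question
out_einsum = torch.einsum("btgi,gih->btgh", x, weight)

# Reference: loop over groups, multiplying group k's slice (b, t, i) by its own (i, h) matrix
out_loop = torch.stack([x[:, :, k, :] @ weight[k] for k in range(g)], dim=2)  # (b, t, g, h)
```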
Is it possible to replace it with a single broadcasted batched matmul call, without a manual loop over the groups? Something like:
torch.matmul(x, self.weight[None, None, :, :, :])
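A note on that guess, as a sketch with hypothetical sizes: torch.matmul treats the last two dimensions of each operand as the matrices, so for x of shape (b, t, g, i) the "matrix" is (g, i), and the batch dimensions (b, t) are broadcast against (1, 1, g) from the weight. That fails whenever t and g do not broadcast, and even when they do, it contracts the wrong axes. Two loop-free alternatives that should match the einsum are shown: unsqueezing x so each group's row becomes a 1-by-i matrix, or moving the group axis to the front and calling torch.bmm.

```python
import torch

torch.manual_seed(0)
# Hypothetical sizes; t != g so the broadcasting problem is visible
b, t, g, i, h = 2, 3, 4, 5, 6
x = torch.randn(b, t, g, i)
weight = torch.randn(g, i, h)  # stands in for self.weight

out_einsum = torch.einsum("btgi,gih->btgh", x, weight)

# The guess: batch dims (b, t) vs (1, 1, g) do not broadcast when t != g
try:
    torch.matmul(x, weight[None, None, :, :, :])
    guess_failed = False
except RuntimeError:
    guess_failed = True

# Option 1: (b, t, g, 1, i) @ (g, i, h); batch dims (b, t, g) broadcast against (g,)
out_matmul = torch.matmul(x.unsqueeze(-2), weight).squeeze(-2)  # (b, t, g, h)

# Option 2: groups first, then g batched products of (b*t, i) @ (i, h)
x_g = x.permute(2, 0, 1, 3).reshape(g, b * t, i)
out_bmm = torch.bmm(x_g, weight).reshape(g, b, t, h).permute(1, 2, 0, 3)
```

Option 2 issues g large matrix products instead of b*t*g tiny vector-matrix products, which is often friendlier to BLAS/cuBLAS, but that is shape- and device-dependent.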
Would it be faster than the einsum?
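On speed: as far as I understand, torch.einsum already lowers a contraction like this to permutes/reshapes plus a batched matmul internally, so the gap may be small, and measuring on the real shapes and device is the only reliable answer. A minimal timing sketch using torch.utils.benchmark; the default sizes are hypothetical placeholders, not DeepFilterNet's actual dimensions.

```python
import torch
import torch.utils.benchmark as benchmark


def compare(b=8, t=100, g=8, i=32, h=32, device="cpu", number=50):
    """Return mean seconds per call for einsum, unsqueeze+matmul, and bmm variants."""
    x = torch.randn(b, t, g, i, device=device)
    weight = torch.randn(g, i, h, device=device)
    stmts = {
        "einsum": "torch.einsum('btgi,gih->btgh', x, weight)",
        "matmul": "torch.matmul(x.unsqueeze(-2), weight).squeeze(-2)",
        "bmm": (
            "torch.bmm(x.permute(2, 0, 1, 3).reshape(g, b * t, i), weight)"
            ".reshape(g, b, t, h).permute(1, 2, 0, 3)"
        ),
    }
    env = {"torch": torch, "x": x, "weight": weight, "b": b, "t": t, "g": g, "i": i, "h": h}
    results = {}
    for name, stmt in stmts.items():
        timer = benchmark.Timer(stmt=stmt, globals=env)
        results[name] = timer.timeit(number).mean  # Measurement.mean, in seconds
    return results


if __name__ == "__main__":
    for name, seconds in compare().items():
        print(f"{name}: {seconds * 1e6:.1f} us")
```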