Speeding up a sum of products (that is not a matmul)

My application spends most of its time in a particular tensor operation:

>>> a # shape (..., N, M)
>>> b # shape (N, M)
>>> (a*b).sum(dim=-1)

With einsum, the operation can be written as

>>> torch.einsum('...ij,ij->...i', (a,b))
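For concreteness, a self-contained check that the two formulations agree (the sizes here are just example assumptions, not my real shapes):

```python
import torch

# Example sizes (assumptions): batch dims (32, 8), N=5, M=7
a = torch.randn(32, 8, 5, 7)   # shape (..., N, M)
b = torch.randn(5, 7)          # shape (N, M)

ref = (a * b).sum(dim=-1)                    # materializes the a*b temporary
ein = torch.einsum('...ij,ij->...i', a, b)   # same result without the explicit temporary

assert torch.allclose(ref, ein, atol=1e-6)
assert ref.shape == (32, 8, 5)
```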

Einsum is actually faster in all of my tests, both on CPU and GPU – maybe because it doesn’t need to allocate the temporary a*b?

Since this computation is the bottleneck in my application, I thought I’d ask whether anyone here knows how to speed it up further. I’m not sure whether jitting would help (I’ve never used it)?
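On the jitting question: a minimal way to try it would be `torch.jit.script` over the op. Whether this actually fuses the multiply and the reduction depends on backend and version, so treat this as a sketch with example sizes, not a measured result:

```python
import torch

@torch.jit.script
def mul_sum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # The bottleneck op; TorchScript *may* fuse the elementwise
    # multiply with the sum reduction, avoiding the a*b temporary.
    return (a * b).sum(dim=-1)

# Example sizes (assumptions)
a = torch.randn(32, 8, 5, 7)
b = torch.randn(5, 7)
out = mul_sum(a, b)
assert out.shape == (32, 8, 5)
```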

Thanks in advance,

EDIT: permuted versions of the tensors are free – I can produce these tensors in whatever shapes make the computation faster.
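Since layouts are free, one thing that might be worth benchmarking is a matmul-based reformulation: each output element is a dot product of a row of `a` with the matching row of `b`, so the whole op can be phrased as a broadcasted batched matmul. A sketch (sizes are illustrative; I have not measured whether this is faster):

```python
import torch

# Example sizes (assumptions): batch dims (32, 8), N=5, M=7
a = torch.randn(32, 8, 5, 7)   # shape (..., N, M)
b = torch.randn(5, 7)          # shape (N, M)

ref = (a * b).sum(dim=-1)

# Row-wise dot products as a broadcasted matmul:
# a.unsqueeze(-2): (..., N, 1, M); b.unsqueeze(-1): (N, M, 1)
# matmul broadcasts the batch dims and yields (..., N, 1, 1).
mm = torch.matmul(a.unsqueeze(-2), b.unsqueeze(-1)).squeeze(-1).squeeze(-1)

assert torch.allclose(ref, mm, atol=1e-6)
```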

Bump – I promise not to bump again


Any luck speeding it up? I have a similar use case…