My application spends most of its time doing a particular tensor operation:
>>> a                    # shape (..., N, M)
>>> b                    # shape (N, M)
>>> (a * b).sum(dim=-1)  # shape (..., N)
With einsum, the operation can be written as
>>> torch.einsum('...ij,ij->...i', (a,b))
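For reference, a quick sanity check that the two formulations agree (the shapes below are just placeholders I picked for illustration):

>>> import torch
>>> a = torch.randn(4, 7, 5, 6)   # (..., N, M) with N=5, M=6
>>> b = torch.randn(5, 6)         # (N, M)
>>> torch.allclose((a * b).sum(dim=-1), torch.einsum('...ij,ij->...i', a, b))
True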
Einsum is actually faster in all of my tests, both on CPU and GPU (maybe because it doesn’t need to allocate the temporary a*b?)…
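For concreteness, the kind of comparison I mean can be set up roughly like this (example sizes only, not my exact workload; torch.utils.benchmark handles warmup and CUDA synchronization):

>>> import torch
>>> import torch.utils.benchmark as benchmark
>>> a = torch.randn(1024, 256, 128)   # (..., N, M), made-up example sizes
>>> b = torch.randn(256, 128)         # (N, M)
>>> t_mul = benchmark.Timer(stmt='(a * b).sum(dim=-1)', globals={'a': a, 'b': b})
>>> t_einsum = benchmark.Timer(stmt="torch.einsum('...ij,ij->...i', a, b)", globals={'torch': torch, 'a': a, 'b': b})
>>> print(t_mul.timeit(100))
>>> print(t_einsum.timeit(100))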
Since this particular computation is the bottleneck in my application, I thought I’d check whether someone here knows how to speed it up even more. Would jitting help, for example? (I’ve never used it.)
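To be clear about what I mean by jitting, roughly something like this (untested on my side, just the idea of letting the fuser combine the mul and the sum):

>>> @torch.jit.script
... def weighted_sum(a, b):
...     # same operation as above, wrapped so the JIT can (maybe) fuse mul + sum
...     return (a * b).sum(dim=-1)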
Thanks in advance,
EDIT: permuted versions of the tensors are free: I can produce these tensors with whatever shapes make the computation faster.
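For example (just to illustrate what “free” means here), a transposed layout is equally easy for me to produce upstream:

>>> a_t = a.transpose(-1, -2).contiguous()   # (..., M, N)
>>> b_t = b.t().contiguous()                 # (M, N)
>>> torch.allclose((a_t * b_t).sum(dim=-2), (a * b).sum(dim=-1))
True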