My application spends most of its time doing a particular tensor operation:
>>> a                    # shape (..., N, M)
>>> b                    # shape (N, M)
>>> (a * b).sum(dim=-1)  # shape (..., N)
With einsum, the operation can be written as
>>> torch.einsum('...ij,ij->...i', (a,b))
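For reference, a quick sanity check that the two formulations agree (the shapes below are just placeholders I picked for illustration):

>>> import torch
>>> a = torch.randn(4, 7, 5, 6)   # (..., N, M) with N=5, M=6
>>> b = torch.randn(5, 6)         # (N, M)
>>> torch.allclose((a * b).sum(dim=-1), torch.einsum('...ij,ij->...i', a, b))
True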
Einsum is actually faster in all of my tests, both on CPU and GPU (maybe because it doesn’t need to allocate the temporary a*b?)…
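For concreteness, the kind of comparison I mean can be set up roughly like this (example sizes only, not my exact workload; torch.utils.benchmark handles warmup and CUDA synchronization):

>>> import torch
>>> import torch.utils.benchmark as benchmark
>>> a = torch.randn(1024, 256, 128)   # (..., N, M), made-up example sizes
>>> b = torch.randn(256, 128)         # (N, M)
>>> t_mul = benchmark.Timer(stmt='(a * b).sum(dim=-1)', globals={'a': a, 'b': b})
>>> t_einsum = benchmark.Timer(stmt="torch.einsum('...ij,ij->...i', a, b)", globals={'torch': torch, 'a': a, 'b': b})
>>> print(t_mul.timeit(100))
>>> print(t_einsum.timeit(100))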
Since this particular computation is the bottleneck in my application, I thought I’d check whether someone here knows how to speed it up even more. Would jitting help, for example? (I’ve never used it.)
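To be clear about what I mean by jitting, roughly something like this (untested on my side, just the idea of letting the fuser combine the mul and the sum):

>>> @torch.jit.script
... def weighted_sum(a, b):
...     # same operation as above, wrapped so the JIT can (maybe) fuse mul + sum
...     return (a * b).sum(dim=-1)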
Thanks in advance,
EDIT: permuted versions of the tensors are free: I can produce these tensors with whatever shapes make the computation faster.
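For example (just to illustrate what “free” means here), a transposed layout is equally easy for me to produce upstream:

>>> a_t = a.transpose(-1, -2).contiguous()   # (..., M, N)
>>> b_t = b.t().contiguous()                 # (M, N)
>>> torch.allclose((a_t * b_t).sum(dim=-2), (a * b).sum(dim=-1))
True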