Hi,

my application spends most of its time doing a particular tensor operation:

```
>>> a # shape (..., N, M)
>>> b # shape (N, M)
>>> (a*b).sum(dim=-1)
```
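
Concretely, with small placeholder sizes (my real shapes are larger, but the pattern is the same):

```
>>> import torch
>>> a = torch.randn(32, 64, 100, 200)  # shape (..., N, M), here ... = (32, 64)
>>> b = torch.randn(100, 200)          # shape (N, M)
>>> (a * b).sum(dim=-1).shape          # result has shape (..., N)
torch.Size([32, 64, 100])
```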

With `einsum`, the operation can be written as

```
>>> torch.einsum('...ij,ij->...i', a, b)
```

Einsum is actually faster in all of my tests, both on CPU and GPU (maybe because it doesn't need to allocate the temporary `a*b`?).
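
For reference, this is roughly how I'm comparing the two (sizes are just placeholders; the CUDA synchronisations are only there so the GPU timings are honest):

```
import time
import torch

def mul_sum(a, b):
    return (a * b).sum(dim=-1)

def via_einsum(a, b):
    return torch.einsum('...ij,ij->...i', a, b)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.randn(32, 64, 100, 200, device=device)  # (..., N, M)
b = torch.randn(100, 200, device=device)          # (N, M)

for fn in (mul_sum, via_einsum):
    fn(a, b)  # warm-up
    if device == 'cuda':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        fn(a, b)
    if device == 'cuda':
        torch.cuda.synchronize()
    print(fn.__name__, time.perf_counter() - t0)
```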

Since this particular computation is *the* bottleneck in my application, I thought I'd check whether someone here knows how to speed it up even more. I'm not sure whether jitting would help (I've never used it).
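
In case it matters, this is the kind of thing I imagine trying, though I have no idea whether the fuser even kicks in for this pattern:

```
import torch

@torch.jit.script
def mul_sum(a, b):
    # same computation as above; the hope is that scripting lets the
    # fuser avoid materialising the a*b temporary
    return (a * b).sum(dim=-1)
```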

Thanks in advance,

Enrico

EDIT: permuted versions of the tensors are free; I can produce these tensors with whatever shapes make the computation faster.
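
For example, if a batched-matmul formulation happens to be faster, I can hand the tensors over in whatever layout it needs; as far as I can tell, the version below computes the same thing (a dot product per row), though I don't know whether it actually helps:

```
>>> # equivalent to (a * b).sum(dim=-1), written as a broadcasted matmul:
>>> # (..., N, 1, M) @ (N, M, 1) -> (..., N, 1, 1)
>>> (a.unsqueeze(-2) @ b.unsqueeze(-1)).squeeze(-1).squeeze(-1)
```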