How to speed up logsumexp on upper-triangular 3-way tensors?



I need to compute logsumexp(X + Y), where X and Y are upper triangular in dims 0 and 1 (think of dim 0 as rows, dim 1 as columns, and dim 2 as threads) and have a minimum size of 4000x4000x7000 (in the near future this will be much larger). I am struggling to speed this up, as I am relatively new to PyTorch. The slice at each (dim 0, dim 1) position of X contains the exact same elements, since X was constructed via np.tile("7k 1d array", (4000, 4000, 1)).

I am currently using the following:

result = torch.logsumexp(torch.add(X, Y), dim=-1)
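A scaled-down sketch of this setup may make the question easier to reproduce. The sizes `n` and `k` and the names `v` and `result_broadcast` are assumptions for illustration only, and the upper-triangular structure is not modeled because the post does not say how the lower triangle is filled. Since every position of X holds the same 1D vector, adding that vector with broadcasting should give the same values without materializing X. Whether that removes the actual bottleneck at full size is untested here.

```python
import numpy as np
import torch

torch.manual_seed(0)

# Hypothetical small sizes standing in for 4000 x 4000 x 7000.
n, k = 6, 5

# The 1D array that gets tiled (stands in for the "7k 1d array").
v = torch.randn(k, dtype=torch.float64)

# X as described in the post: the same vector at every (dim 0, dim 1) position.
X = torch.from_numpy(np.tile(v.numpy(), (n, n, 1)))

# Y is arbitrary here; the triangular structure is not modeled.
Y = torch.randn(n, n, k, dtype=torch.float64)

# Current approach from the post.
result = torch.logsumexp(torch.add(X, Y), dim=-1)

# Same computation without building X: v broadcasts over dims 0 and 1.
result_broadcast = torch.logsumexp(Y + v, dim=-1)

print(result.shape)
print(torch.allclose(result, result_broadcast))
```

If the two results agree at small scale, the tiled X may not need to be stored at all, which would save one full-size tensor of memory.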

Any ideas on how to speed this up?