I need to compute logsumexp(X + Y), where Y is upper triangular in dims 0 and 1 (dim 0 is rows, dim 1 is columns, and dim 2 is threads) and is at least 4000x4000x7000 in size (in the near future this will be much larger). I am struggling to speed this up, as I am relatively new to PyTorch. Every slice X[i, j, :] (fixing dims 0 and 1) contains exactly the same elements, because X was constructed via
np.tile("7k 1d array", (4000, 4000, 1))
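To illustrate what that construction produces, here is a small-scale sketch, assuming NumPy and hypothetical sizes `n` and `k` in place of 4000 and 7000. `np.tile` writes a full copy of the 1-D array at every (row, column) position, whereas `np.broadcast_to` gives the same values as a read-only view whose strides in dims 0 and 1 are zero, so nothing is copied. At the full size, the tiled array holds 4000 x 4000 x 7000 = 1.12e11 elements, roughly 448 GB in float32.

```python
import numpy as np

# Hypothetical small sizes standing in for 4000 (rows/columns) and 7000 (threads).
n, k = 5, 7
a = np.random.default_rng(0).standard_normal(k)  # stand-in for the "7k 1d array"

# Construction from the question: a full copy of `a` at every (row, column) position.
X = np.tile(a, (n, n, 1))
assert X.shape == (n, n, k)
assert all(np.array_equal(X[i, j], a) for i in range(n) for j in range(n))

# Same values as a read-only broadcast view: dims 0 and 1 have stride 0, so no copy.
X_view = np.broadcast_to(a, (n, n, k))
assert np.array_equal(X, X_view)
assert X_view.strides[:2] == (0, 0)

print(X.nbytes, a.nbytes)  # tiled copy vs. the single 1-D array → 1400 56
```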
I am currently using the following:
result = torch.logsumexp(torch.add(X, Y), dim=-1)
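A minimal sketch of two ways to cut the work, assuming PyTorch and hypothetical small sizes, and checked against the expression above: (1) add the 1-D array directly to Y and let broadcasting cover dims 0 and 1, so X is never built; (2) additionally assuming that the entries of Y below the diagonal (row > column) are zero, every such position reduces to logsumexp of the 1-D array alone, so only the upper-triangular positions need the full reduction over dim 2.

```python
import torch

torch.manual_seed(0)
n, k = 6, 9                     # hypothetical small sizes (question: 4000, 7000)
a = torch.randn(k)              # the shared 1-D array

# Assumption: entries of Y with row > column are zero (upper triangular in dims 0, 1).
upper = torch.triu(torch.ones(n, n, dtype=torch.bool))
Y = torch.randn(n, n, k)
Y[~upper] = 0.0

# Approach from the question: X is a full (n, n, k) copy of `a`.
X = a.repeat(n, n, 1)
result = torch.logsumexp(torch.add(X, Y), dim=-1)

# (1) Broadcasting: `a` of shape (k,) is added to every (row, column) slice of Y.
result_bcast = torch.logsumexp(Y + a, dim=-1)

# (2) Triangular shortcut: below the diagonal Y is zero, so the value is logsumexp(a);
#     only the upper triangle (diagonal included) is reduced over dim 2.
rows, cols = torch.triu_indices(n, n)
result_tri = torch.full((n, n), torch.logsumexp(a, dim=0).item())
result_tri[rows, cols] = torch.logsumexp(Y[rows, cols] + a, dim=-1)

assert torch.allclose(result, result_bcast)
assert torch.allclose(result, result_tri)
```

Note that `Y + a` still allocates one temporary as large as Y, so at the full size a loop over blocks of rows may be needed to bound memory; the triangular shortcut is only valid if the lower-triangular entries of Y really are zero.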
Any ideas on how to speed this up?