I need to apply logsumexp(X + Y), where X and Y are upper triangular in dims 0 and 1 (dim 0 is rows, dim 1 is columns, and dim 2 is "threads") and are at minimum 4000x4000x7000 in size (in the near future this will be much larger). I am struggling to speed this up, as I am relatively new to PyTorch.

For every index pair (i, j) in dims 0 and 1, the slice X[i, j, :] contains exactly the same elements, because X was constructed via np.tile("7k 1d array", (4000, 4000, 1)).
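A small-scale NumPy sketch of the structure described above may help. The sizes N and K and the random data are stand-ins, not values from the question. It checks that np.tile(a, (N, N, 1)) on a 1D array a gives X[i, j, :] == a for every (i, j), which means X + Y can be formed by broadcasting a against Y without building X. For scale, a materialised X at the stated size holds 4000 · 4000 · 7000 = 1.12 × 10^11 elements, roughly 448 GB in float32, and X + Y is another temporary of the same size.

```python
import numpy as np

# Small stand-in sizes (assumption); the question's real sizes are N=4000, K=7000.
N, K = 5, 7
rng = np.random.default_rng(0)

a = rng.standard_normal(K)           # plays the role of the "7k 1d array" (random here)
X = np.tile(a, (N, N, 1))            # shape (N, N, K), built the same way as in the question
Y = rng.standard_normal((N, N, K))   # placeholder for Y (triangular structure ignored here)

# Every fiber X[i, j, :] along the last axis is the same vector a ...
assert X.shape == (N, N, K)
assert all(np.array_equal(X[i, j], a) for i in range(N) for j in range(N))

# ... so X + Y equals a + Y, with a broadcast across axes 0 and 1.
assert np.array_equal(X + Y, a + Y)
```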
I am currently using the following:
result = torch.logsumexp(torch.add(X, Y), dim=-1)
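One possible variant of the call above, sketched under the assumption that the shared 1D vector is available as a tensor a of shape (7000,): pass a directly and let PyTorch broadcast it against Y, processing dim 0 in chunks so only a (chunk, 4000, 7000) temporary exists at a time. The function name logsumexp_broadcast_chunked and the chunk parameter are hypothetical, the sketch ignores the upper-triangular structure, and whether it is faster than the original in practice would need to be measured; its main effect is avoiding the full tiled X and bounding peak memory.

```python
import torch


def logsumexp_broadcast_chunked(a: torch.Tensor, Y: torch.Tensor, chunk: int = 256) -> torch.Tensor:
    """Hypothetical helper: logsumexp(tile(a) + Y, dim=-1) without building the tiled X.

    a:     shape (K,), the vector shared by every X[i, j, :].
    Y:     shape (N, M, K).
    chunk: number of dim-0 rows processed per step (assumed tuning knob).
    """
    out = torch.empty(Y.shape[:-1], dtype=Y.dtype, device=Y.device)
    for start in range(0, Y.shape[0], chunk):
        # a broadcasts across dims 0 and 1; only this chunk's sum is materialised
        out[start:start + chunk] = torch.logsumexp(Y[start:start + chunk] + a, dim=-1)
    return out


# Small-scale comparison against the original formulation (sizes are stand-ins).
torch.manual_seed(0)
N, K = 6, 9
a = torch.randn(K)
Y = torch.randn(N, N, K)
X = a.repeat(N, N, 1)  # same layout as np.tile(a, (N, N, 1))
ref = torch.logsumexp(torch.add(X, Y), dim=-1)
res = logsumexp_broadcast_chunked(a, Y, chunk=4)
print(torch.allclose(ref, res))
```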
Any ideas on how to speed this up?