There is an implementation of `logsumexp` in PyTorch, but there is no `logcumsumexp`. How can it be implemented efficiently and stably in PyTorch?

# Stable and efficient implementation of logcumsumexp

Hello Artyom!

As it stands, it can't be, short of someone writing it.

You presumably understand the numerical issues with calculating `logsumexp`. (See, for example, the discussion in Wikipedia's LogSumExp article.)
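As a quick illustration of those issues (my own example, not from the original post): the naive formula `log(sum(exp(x)))` overflows for large inputs, while the shifted form `max(x) + log(sum(exp(x - max(x))))`, which `torch.logsumexp` uses internally, stays finite.

```
import torch

x = torch.tensor([1000.0, 1000.0])

# Naive formula: exp(1000.) overflows float32 to inf.
naive = torch.log(torch.exp(x).sum())

# Stable built-in: shifts by max(x) internally.
stable = torch.logsumexp(x, dim=0)  # 1000 + log(2)
```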

So if pytorch had, for example, a `cummax` tensor function, you could implement `logcumsumexp` using pytorch tensor functions.

But this doesn’t exist (yet). See:

https://github.com/pytorch/pytorch/issues/20240

and

https://discuss.pytorch.org/t/sliding-max-over-dimension/49799

So, short of writing the `logcumsumexp` (or related) tensor function "from scratch," you would have to use a loop to get the "running maximum" (`cummax`) part, thus forgoing some of the efficiency provided by using just tensor functions.
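A minimal sketch of that loop-based approach (my own illustration; `logcumsumexp_loop` is a hypothetical name, and the Python-level loop per step is exactly the efficiency cost described above). Each step folds the next slice into a running logsumexp, so stability is inherited from `torch.logsumexp`:

```
import torch

def logcumsumexp_loop(x, dim=-1):
    # Maintain a running logsumexp along `dim`, one slice at a time.
    slices = x.unbind(dim)
    out = [slices[0]]
    for s in slices[1:]:
        # Stable log(exp(prev) + exp(s)) via logsumexp over a stacked pair.
        out.append(torch.logsumexp(torch.stack([out[-1], s]), dim=0))
    return torch.stack(out, dim=dim)
```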

Good luck.

K. Frank

@KFrank Thank you for your answer. I created a PR; maybe someone will add `cummax` and `logcumsumexp` to ATen so they can be efficient. For the moment I ended up with the following implementation; maybe someone will find it useful in the future.

```
import torch
import numpy as np


def cummax(x, dim):
    # PyTorch has no cummax yet, so round-trip through NumPy.
    # Note: this detaches from the autograd graph, but the cummax term
    # cancels analytically in logcumsumexp, so gradients w.r.t. x survive.
    x_np = x.detach().cpu().numpy()
    ret = np.maximum.accumulate(x_np, axis=dim)
    return torch.from_numpy(ret).to(x)


def logcumsumexp(x, dim=-1):
    # Move the target dim to the end if it isn't there already.
    if (dim != -1) and (dim != x.ndimension() - 1):
        x = x.transpose(dim, -1)
    init_size = x.size()
    last_dim_size = init_size[-1]
    x_resized = x.contiguous().view(-1, last_dim_size)
    d1, d2 = x_resized.size()
    # Running maximum along the last dim, shaped for broadcasting.
    x_cummax = cummax(x_resized, -1).view(d1, d2, 1)
    x_expand = x_resized.unsqueeze(1).expand(d1, d2, last_dim_size)
    # Lower-triangular mask selects the prefix entries for each position.
    mask = torch.tril(torch.ones(last_dim_size, last_dim_size)).unsqueeze(0).to(x)
    ret = torch.log(torch.sum(torch.exp(x_expand - x_cummax) * mask, dim=-1)) + x_cummax.view(d1, d2)
    ret = ret.view(*init_size)
    if (dim != -1) and (dim != x.ndimension() - 1):
        ret = ret.transpose(-1, dim)
    return ret
```
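For readers finding this thread later: both operations were eventually added to PyTorch natively. A quick sketch, assuming a recent enough version (roughly 1.5+ for `torch.cummax` and 1.6+ for `torch.logcumsumexp`):

```
import torch

x = torch.randn(2, 5)

# Native cumulative max: returns a (values, indices) pair.
vals, idx = torch.cummax(x, dim=-1)

# Native, numerically stable logcumsumexp.
out = torch.logcumsumexp(x, dim=-1)
```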