Stable and efficient implementation of logcumsumexp

There is an implementation of logsumexp in pytorch, but there is no logcumsumexp. How can it be implemented efficiently and stably in pytorch?

Hello Artyom!

As it stands, it can't be, unless someone writes it.

You presumably understand the numerical issues with calculating
logsumexp. (See, for example, the discussion in Wikipedia’s
LogSumExp article.)
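For reference, the issue is that exp() overflows for large inputs; the usual fix shifts by the maximum, using log Σ exp(x_i) = m + log Σ exp(x_i − m) with m = max_i x_i, so every exponent is ≤ 0. A minimal sketch of both versions (the helper names are hypothetical, chosen for illustration):

```python
import torch


def naive_logsumexp(x: torch.Tensor) -> torch.Tensor:
    # Direct formula: exp(1000.0) overflows float32 (max is about 3.4e38) to inf.
    return torch.log(torch.sum(torch.exp(x)))


def stable_logsumexp(x: torch.Tensor) -> torch.Tensor:
    # Shift by the maximum so every exponent is <= 0, then add the maximum back.
    m = x.max()
    return m + torch.log(torch.sum(torch.exp(x - m)))


x = torch.tensor([1000.0, 1000.0])
print(naive_logsumexp(x))   # inf
print(stable_logsumexp(x))  # 1000 + log(2), about 1000.6931
```

The stable version agrees with the built-in torch.logsumexp.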

So if pytorch had, for example, a cummax tensor function, you
could implement logcumsumexp using pytorch tensor functions.
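The reason cummax is the missing piece: with the running maximum m_i = max_{j≤i} x_j, each prefix can be shifted by its own maximum, log Σ_{j≤i} exp(x_j) = m_i + log Σ_{j≤i} exp(x_j − m_i), where every exponent is ≤ 0 and at least one equals 0. A minimal sketch, assuming a pytorch version that provides torch.cummax and Tensor.movedim (later releases added both); the name logcumsumexp_cummax is hypothetical. It builds an n × n matrix per row, so memory is quadratic in the length of the scanned dimension.

```python
import torch


def logcumsumexp_cummax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper; assumes torch.cummax and Tensor.movedim exist.
    x = x.movedim(dim, -1)               # put the scanned dimension last
    m = torch.cummax(x, dim=-1).values   # m[..., i] = max_{j <= i} x[..., j]
    n = x.size(-1)
    # keep[i, j] is True for j <= i (lower triangle, diagonal included)
    keep = torch.ones(n, n, device=x.device).tril().bool()
    # diff[..., i, j] = x[..., j] - m[..., i], which is <= 0 wherever keep is True
    diff = x.unsqueeze(-2) - m.unsqueeze(-1)
    # Excluded terms (j > i) become -inf, so exp() maps them to exactly 0
    diff = diff.masked_fill(~keep, float("-inf"))
    out = m + torch.log(torch.exp(diff).sum(dim=-1))
    return out.movedim(-1, dim)
```

Filling the excluded entries with -inf before exp(), rather than multiplying by a 0/1 mask afterwards, avoids inf * 0 = nan when a later element is much larger than the current running maximum.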

But this doesn’t exist (yet). See:

https://stackoverflow.com/questions/55665624/vectorized-implementation-of-cumulative-maximum-in-pytorch-with-requires-grad-tr

https://github.com/pytorch/pytorch/issues/20240

and

https://discuss.pytorch.org/t/sliding-max-over-dimension/49799

So, short of writing the logcumsumexp (or related) tensor
function “from scratch,” you would have to use a loop to get
the “running maximum” (cummax) part, thus forgoing some
of the efficiency provided by using just tensor functions.
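The loop-based route can be sketched as a streaming scan: keep a running maximum m and a running sum s of exp(x_j − m), and whenever the maximum grows, rescale s by exp(m_old − m_new). This is a standard rescaling trick, not necessarily what is meant above; the name logcumsumexp_loop is hypothetical. It takes one Python iteration per position along dim, but each step is vectorized over the other dimensions and needs no n × n intermediate.

```python
import torch


def logcumsumexp_loop(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper: sequential scan along `dim`, stable via rescaling.
    xt = x.transpose(0, dim)        # put the scanned dimension first
    m = xt[0]                       # running maximum
    s = torch.ones_like(m)          # running sum of exp(x_j - m); starts at exp(0)
    outs = [m]                      # first output is x_0 itself
    for i in range(1, xt.size(0)):
        xi = xt[i]
        m_new = torch.max(m, xi)    # elementwise maximum of two tensors
        # Rescale the old sum to the new maximum, then add the new term.
        # Both exponents are <= 0, so nothing overflows.
        s = s * torch.exp(m - m_new) + torch.exp(xi - m_new)
        m = m_new
        outs.append(m + torch.log(s))
    return torch.stack(outs, dim=0).transpose(0, dim)
```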

Good luck.

K. Frank

@KFrank Thank you for your answer. I created a PR; maybe someone will add cummax and logcumsumexp to ATen so they are efficient. For now I have ended up with the following implementation, in case someone needs it in the future.

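import torch
import numpy as np


def cummax(x, dim):
    # Running maximum along `dim`, computed in NumPy (round-trips through the CPU).
    # The result is detached from autograd; that is fine here because it is only
    # used as a stabilizing shift that is subtracted and then added back.
    x_np = x.detach().cpu().numpy()
    ret = np.maximum.accumulate(x_np, axis=dim)
    return torch.from_numpy(ret).to(x)


def logcumsumexp(x, dim=-1):
    # Move the scanned dimension to the end (skipped if it is already last).
    transposed = dim not in (-1, x.ndimension() - 1)
    if transposed:
        x = x.transpose(dim, -1)

    init_size = x.size()
    last_dim_size = init_size[-1]
    # Flatten the leading dimensions: shape (d1, d2) with d2 == last_dim_size.
    x_resized = x.contiguous().view(-1, last_dim_size)
    d1, d2 = x_resized.size()
    # x_cummax[b, i, 0] = max_{j <= i} x[b, j]
    x_cummax = cummax(x_resized, -1).view(d1, d2, 1)
    # x_expand[b, i, j] = x[b, j]; the subtraction below materializes a
    # (d1, n, n) tensor, so memory grows quadratically with n.
    x_expand = x_resized.unsqueeze(1).expand(d1, d2, last_dim_size)
    # mask[0, i, j] is True for j <= i; built on the same device as x.
    mask = torch.tril(torch.ones(last_dim_size, last_dim_size, device=x.device)).bool().unsqueeze(0)
    # Fill j > i with -inf before exp() so those terms become exactly 0.
    # (Multiplying exp() by a 0/1 mask gives inf * 0 = nan when x[b, j] >> x_cummax[b, i].)
    shifted = (x_expand - x_cummax).masked_fill(~mask, float("-inf"))
    ret = torch.log(torch.sum(torch.exp(shifted), dim=-1)) + x_cummax.view(d1, d2)
    ret = ret.view(*init_size)

    if transposed:
        ret = ret.transpose(-1, dim)

    return ret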