# Stable and efficient implementation of logcumsumexp

There is an implementation of `logsumexp` in pytorch, but there is no `logcumsumexp`. How can it be implemented efficiently and in a numerically stable way in pytorch?

Hello Artyom!

As things stand, it can't, short of someone writing it.

You presumably understand the numerical issues with calculating
`logsumexp`. (See, for example, the discussion in Wikipedia's
LogSumExp article.)
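As a reminder of that issue, here is a minimal pure-Python sketch of the usual max-shift trick. `naive_logsumexp` and `stable_logsumexp` are hypothetical helper names used only for illustration.

```python
import math

def naive_logsumexp(xs):
    # Direct formula: math.exp raises OverflowError for large inputs.
    return math.log(sum(math.exp(x) for x in xs))

def stable_logsumexp(xs):
    # Shift by the maximum so every exponent is <= 0 and exp() stays in (0, 1].
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

xs = [1000.0, 1000.0]
print(stable_logsumexp(xs))  # 1000 + log(2), about 1000.693
try:
    naive_logsumexp(xs)
except OverflowError:
    print("naive version overflows")
```

Because the largest shifted term is exactly `exp(0) = 1`, the sum inside the log can neither overflow nor underflow to zero.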

So if pytorch had, for example, a `cummax` tensor function, you
could implement `logcumsumexp` using pytorch tensor functions.
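Concretely, with $m_i$ the running maximum (what a `cummax` along the scanned dimension would return), each output element can be rewritten so that every exponent is non-positive:

$$
\operatorname{logcumsumexp}(x)_i = \log \sum_{j \le i} e^{x_j} = m_i + \log \sum_{j \le i} e^{x_j - m_i}, \qquad m_i = \max_{j \le i} x_j .
$$

Since $x_j - m_i \le 0$ for every $j \le i$, and one term equals $e^0 = 1$, the inner sum is at least 1 and cannot overflow. Note that the shift $m_i$ depends on $i$, so the inner sum is not a plain `cumsum` of one shifted tensor; it needs either a masked $O(n^2)$ sum or a sequential rescaling step.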

But this doesn’t exist (yet). See:

https://stackoverflow.com/questions/55665624/vectorized-implementation-of-cumulative-maximum-in-pytorch-with-requires-grad-tr

https://github.com/pytorch/pytorch/issues/20240

and

https://discuss.pytorch.org/t/sliding-max-over-dimension/49799

So, short of writing the `logcumsumexp` (or related) tensor
function “from scratch,” you would have to use a loop to get
the “running maximum” (`cummax`) part, thus forgoing some
of the efficiency provided by using just tensor functions.
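A minimal sketch of that loop-based route, assuming the scan runs over the last dimension and is vectorized over all other dimensions. Instead of computing `cummax` separately, it keeps the running maximum inside the loop and rescales the running sum whenever the maximum grows (the usual "online" log-sum-exp update, swapped in here as one possible technique). `logcumsumexp_loop` is a hypothetical name.

```python
import torch

def logcumsumexp_loop(x: torch.Tensor) -> torch.Tensor:
    """Log-cumulative-sum-exp along the last dimension with a Python loop.

    Keeps a running max m and a running sum s = sum_{j<=i} exp(x_j - m).
    When m increases, s is rescaled by exp(m_old - m_new) <= 1.
    Assumes finite inputs: -inf entries would produce nan in exp(m - m_new).
    """
    outs = []
    m = x[..., 0]
    s = torch.ones_like(m)  # exp(x_0 - m) == 1 because m == x_0
    outs.append(m + torch.log(s))
    for i in range(1, x.size(-1)):
        xi = x[..., i]
        m_new = torch.maximum(m, xi)
        # Rescale the old sum to the new shift, then add the new term.
        s = s * torch.exp(m - m_new) + torch.exp(xi - m_new)
        m = m_new
        outs.append(m + torch.log(s))
    return torch.stack(outs, dim=-1)

x = torch.tensor([[0.0, 1000.0, 1.0]])
print(logcumsumexp_loop(x))  # finite, while log(cumsum(exp(x))) overflows to inf here
```

The per-step work is fully vectorized, but the Python loop over the scanned dimension is exactly the efficiency cost described above.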

Good luck.

K. Frank

@KFrank Thank you for your answer. I created a PR, so maybe someone will add `cummax` and `logcumsumexp` to ATen to make them efficient. For now I've ended up with the following implementation; maybe someone will need it in the future.

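```python
import torch
import numpy as np

def cummax(x, dim):
    # Running maximum along `dim`, computed with NumPy. The result is detached
    # from the autograd graph; that is harmless here because the shift cancels
    # out of the final result (its gradient is zero).
    x_np = x.detach().cpu().numpy()
    ret = np.maximum.accumulate(x_np, axis=dim)
    return torch.from_numpy(ret).to(x)

def logcumsumexp(x, dim=-1):
    # Normalize `dim` and move it to the last position if necessary.
    dim = dim % x.dim()
    last = x.dim() - 1
    if dim != last:
        x = x.transpose(dim, last)

    init_size = x.size()
    last_dim_size = init_size[-1]
    x_resized = x.contiguous().view(-1, last_dim_size)
    d1, d2 = x_resized.size()
    # x_cummax[r, i, 0] = max(x_resized[r, :i + 1])
    x_cummax = cummax(x_resized, -1).view(d1, d2, 1)
    # x_expand[r, i, j] = x_resized[r, j]
    x_expand = x_resized.unsqueeze(1).expand(d1, d2, last_dim_size)
    # mask[0, i, j] is True for j <= i, on the same device as x
    mask = torch.ones(last_dim_size, last_dim_size, device=x.device).tril().bool().unsqueeze(0)
    # Set j > i to -inf *before* exp: multiplying exp(...) by a 0/1 mask
    # afterwards gives inf * 0 = nan when a later element is much larger.
    shifted = (x_expand - x_cummax).masked_fill(~mask, float("-inf"))
    ret = torch.log(torch.exp(shifted).sum(dim=-1)) + x_cummax.view(d1, d2)
    ret = ret.view(*init_size)

    if dim != last:
        ret = ret.transpose(last, dim)

    return ret
```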
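For reference, later PyTorch releases did add both functions: `torch.cummax` (which returns a `(values, indices)` pair) and `torch.logcumsumexp`. On a version that has them, the built-in is the simplest choice. The sketch below also shows a shorter $O(n^2)$ variant of the masked approach that needs no `cummax` at all, because `torch.logsumexp` already subtracts the per-row maximum once the masked-out positions are set to `-inf`. `logcumsumexp_masked` is a hypothetical name, and the comparison assumes a PyTorch version that provides `torch.logcumsumexp`, `torch.cummax` and `Tensor.movedim`.

```python
import torch

def logcumsumexp_masked(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Move the scanned dimension to the end.
    x = x.movedim(dim, -1)
    n = x.size(-1)
    idx = torch.arange(n, device=x.device)
    # keep[i, j] is True for j <= i (lower-triangular mask).
    keep = idx.unsqueeze(1) >= idx.unsqueeze(0)
    # pairs[..., i, j] = x[..., j] if j <= i else -inf; memory is O(n^2) per row.
    pairs = x.unsqueeze(-2).expand(*x.shape[:-1], n, n).masked_fill(~keep, float("-inf"))
    # logsumexp shifts by the row max internally, so this is numerically stable.
    out = torch.logsumexp(pairs, dim=-1)
    return out.movedim(-1, dim)

x = torch.randn(2, 5, 3, dtype=torch.float64)
print(torch.allclose(logcumsumexp_masked(x, dim=1), torch.logcumsumexp(x, dim=1)))

values, indices = torch.cummax(torch.tensor([1.0, 3.0, 2.0]), dim=0)
print(values, indices)  # running max [1., 3., 3.] reached at positions [0, 1, 1]
```

Like the implementation above, this trades memory for avoiding a Python loop; for long sequences the built-in `torch.logcumsumexp` avoids the $n \times n$ intermediate.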