Unexpected behaviour with torch.distributions.utils.log_sum_exp()


#1

I’m using torch.distributions.utils.log_sum_exp, and I’ve noticed it returns the wrong shape when you pass keepdim=False.

>>> torch.distributions.utils.log_sum_exp(torch.ones([10, 100]), keepdim=True).shape
torch.Size([10, 1])
>>> torch.distributions.utils.log_sum_exp(torch.ones([10, 100]), keepdim=False).shape
torch.Size([10, 10])

Shouldn’t keepdim=False drop the reduced dimension and give torch.Size([10])? I’m working with larger tensors and getting 5000x5000 outputs instead of a vector.

Or is there an alternative log-sum-exp I should be using?
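For reference, here is what I’d expect the reduction to compute. This is just a minimal pure-Python sketch of the numerically stable max-shift trick (not PyTorch’s implementation): log sum_i exp(v_i) = m + log sum_i exp(v_i - m), with m = max(v), so large inputs don’t overflow exp.

```python
import math

def log_sum_exp(values):
    # Stable log(sum(exp(v))): shift by the max so every exponent is <= 0.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

# A row of 100 ones reduces to a single scalar:
# log(100 * e^1) = 1 + log(100) ≈ 5.6052
print(log_sum_exp([1.0] * 100))
```

Applied row-wise to a [10, 100] tensor, this would give 10 scalars, i.e. shape [10], not [10, 10].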


#2

I think that’s a bug, but I can’t find documentation for torch.distributions.utils.log_sum_exp so I can’t confirm – do you know where the docs for that are?


#3

I couldn’t find any docs either – I just happened to stumble on it while searching the GitHub repo for a native log-sum-exp function!


#4

On master there is torch.logsumexp (https://github.com/pytorch/pytorch/pull/7254), but it hasn’t been documented yet. I’ll make a note that torch.distributions.utils.log_sum_exp appears to behave incorrectly.
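On a build that includes that PR, the new operator should follow the usual reduction semantics, where keepdim=False drops the reduced dimension. A quick sketch (assuming the signature from the linked PR, torch.logsumexp(input, dim, keepdim=False)):

```python
import torch

x = torch.ones(10, 100)

# keepdim=False drops the reduced dimension, as with torch.sum etc.
print(torch.logsumexp(x, dim=1).shape)                # torch.Size([10])
print(torch.logsumexp(x, dim=1, keepdim=True).shape)  # torch.Size([10, 1])
```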


#5

Issue reported here: https://github.com/pytorch/pytorch/issues/8426


#6

Great, thanks. I’ll try that out later.