I’m using `torch.distributions.utils.log_sum_exp`, and I’ve noticed it seems to break when called with `keepdim=False`.

```
>>> torch.distributions.utils.log_sum_exp(torch.ones([10, 100]), keepdim=True).shape
torch.Size([10, 1])
>>> torch.distributions.utils.log_sum_exp(torch.ones([10, 100]), keepdim=False).shape
torch.Size([10, 10])
```

Shouldn’t it give `torch.Size([10])`? I’m working with larger inputs, and with `keepdim=False` I’m getting 5000x5000 tensors as output.

Or is there an alternative log-sum-exp I should be using?
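For now I’m working around it with a manual log-sum-exp along these lines (a minimal sketch, not the library implementation), which handles `keepdim` the way I’d expect:

```python
import torch

def log_sum_exp(value, dim=-1, keepdim=False):
    # Numerically stable log(sum(exp(value))) along `dim`:
    # subtract the max before exponentiating to avoid overflow.
    m, _ = torch.max(value, dim=dim, keepdim=True)
    out = m + torch.log(torch.sum(torch.exp(value - m), dim=dim, keepdim=True))
    # Drop the reduced dimension unless keepdim=True.
    return out if keepdim else out.squeeze(dim)

print(log_sum_exp(torch.ones([10, 100]), keepdim=True).shape)   # torch.Size([10, 1])
print(log_sum_exp(torch.ones([10, 100]), keepdim=False).shape)  # torch.Size([10])
```

With `keepdim=False` this returns the shape I expected, `torch.Size([10])`, but I’d rather use whatever the supported built-in is if one exists.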