Masked LogSumExp

torch.logsumexp doesn’t provide support for masking. Is there an efficient way to implement it with masking?
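For reference, here is the masked quantity being computed. It is the log-sum-exp over kept entries only, shifted by their maximum $m$ for numerical stability. The notation $M_i \in \{0, 1\}$ for the mask is introduced here and is not from the original post:

$$
\operatorname{LSE}_M(x) = \log \sum_{i : M_i = 1} e^{x_i} = m + \log \sum_{i : M_i = 1} e^{x_i - m}, \qquad m = \max_{i : M_i = 1} x_i
$$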
I came up with the function below, but I’m not sure it’s the best way.

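import torch

def masked_logsumexp(x, mask, eps=1e-8):
    # mask has x's shape: 1 keeps an entry, 0 ignores it
    mask = mask.to(x.dtype)
    has_any = (mask.sum(dim=-1) > 0).to(x.dtype)  # 0 for fully masked rows
    m = torch.max(x * mask - (1 - mask) / eps, dim=-1, keepdim=True).values  # max over kept entries
    # zero the exponent at masked positions so exp() cannot overflow there
    lse = m.squeeze(-1) * has_any + torch.log(torch.sum(torch.exp((x - m) * mask) * mask, dim=-1) + eps)
    return lse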
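A possibly simpler alternative is sketched below. It is not from the original post, and the name `masked_logsumexp_fill` is hypothetical. It replaces masked entries with `-inf` using `Tensor.masked_fill` and lets the built-in `torch.logsumexp` do the stable max-shift itself, since `exp(-inf)` contributes zero to the sum. The sketch assumes `mask` has the same shape as `x`, with nonzero/True marking entries to keep.

```python
import torch


def masked_logsumexp_fill(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Log-sum-exp of x over `dim`, counting only entries where mask is nonzero/True.

    Hypothetical helper: masked entries are replaced by -inf, which contributes
    exp(-inf) = 0 to the sum, and the built-in torch.logsumexp handles the
    max-shift. Rows with no kept entries come out as -inf.
    """
    # masked_fill expects a boolean mask; ~mask.bool() selects the entries to ignore
    filled = x.masked_fill(~mask.bool(), float("-inf"))
    return torch.logsumexp(filled, dim=dim)


if __name__ == "__main__":
    x = torch.tensor([[1.0, 2.0, 3.0], [0.5, 100.0, -1.0]])
    mask = torch.tensor([[1, 1, 0], [1, 0, 1]])
    # row 0 uses [1.0, 2.0]; row 1 uses [0.5, -1.0] (the masked 100.0 is ignored)
    print(masked_logsumexp_fill(x, mask))
```

Unlike the eps-based version, this avoids both the `-1/eps` sentinel and the `+ eps` bias inside the log. A fully masked row returns `-inf` rather than roughly `log(eps)`. If such rows can occur and you backpropagate through them, the gradients there may be NaN, so it may be worth handling them separately (for example with `torch.where`).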