Standardize whether logits and probs get normalized internally in distribution functions?

Hi there,
I have been writing some wrappers for loss functions that use various probability distributions in PyTorch. I'm noticing some inconsistencies in the meaning of a few identically named arguments that have slowed me down, and unless I am missing something, this seems like a good area for cleanup. Here is one example that highlights the point:

torch.distributions.Multinomial (which delegates its initialization to torch.distributions.Categorical) normalizes whichever parameter it is given at initialization:

            if probs is not None:
                self.probs = probs / probs.sum(-1, keepdim=True)
            else:
                self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
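
A quick way to check this behavior from the outside (a minimal sketch, assuming a recent PyTorch release; the weights and the +5 offset are arbitrary illustration values):

```python
import torch
from torch.distributions import Categorical, Multinomial

# Arbitrary non-negative weights over 3 categories; they do not sum to 1.
weights = torch.tensor([1.0, 1.0, 2.0])

# Categorical divides probs by their sum along the last dimension.
cat = Categorical(probs=weights)
print(cat.probs)  # tensor([0.2500, 0.2500, 0.5000])

# Multinomial builds an internal Categorical, so its probs are normalized the same way.
multi = Multinomial(total_count=4, probs=weights)
print(multi.probs)

# Logits offset by an arbitrary constant (+5) are shifted so that their
# logsumexp over the last dimension becomes 0, i.e. they become log-probabilities.
raw_logits = weights.log() + 5.0
cat_from_logits = Categorical(logits=raw_logits)
print(cat_from_logits.logits.logsumexp(-1))  # approximately 0
print(cat_from_logits.probs)  # same values as cat.probs
```

Because of this shift, any constant offset added to the logits describes the same Categorical distribution.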

On the other hand, torch.distributions.Binomial doesn’t do this:

        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
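
For comparison, a sketch of how Binomial treats the same arguments (again assuming a recent PyTorch release; the values are arbitrary). Here probs holds one success probability per element rather than a vector over categories, and logits are interpreted as log-odds:

```python
import torch
from torch.distributions import Binomial

# One success probability per batch element (batch of 2), not a distribution over categories.
p = torch.tensor([0.2, 0.7])
binom = Binomial(total_count=10, probs=p)
print(binom.probs)  # stored as given: tensor([0.2000, 0.7000])

# Logits are log-odds of success; probs is derived from them with a sigmoid.
log_odds = torch.tensor([0.0, 2.0])
binom_from_logits = Binomial(total_count=10, logits=log_odds)
print(binom_from_logits.logits)  # stored as given
print(binom_from_logits.probs)   # equals torch.sigmoid(log_odds)

# For reference only: dividing p by its sum, as Categorical does,
# would change these per-element probabilities (0.2 -> ~0.222, 0.7 -> ~0.778).
print(p / p.sum(-1, keepdim=True))
```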

Would it cause any issues if I did something similar to what Categorical does in the other distributions? If the input is already normalized, the normalization lines above should not change anything (up to floating-point rounding). The exception is probs.sum(-1) == 0, where the division produces NaNs, but such an input likely already causes problems downstream.
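
Both points can be checked directly. The sketch below uses two hypothetical helpers, normalize_probs and normalize_logits (not PyTorch APIs), that simply copy the Categorical normalization expressions quoted above:

```python
import torch

def normalize_probs(probs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: same expression Categorical applies to probs.
    return probs / probs.sum(-1, keepdim=True)

def normalize_logits(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: same expression Categorical applies to logits.
    return logits - logits.logsumexp(dim=-1, keepdim=True)

# Already-normalized input: both helpers are no-ops up to floating-point rounding.
already = torch.tensor([0.1, 0.2, 0.7])
print(torch.allclose(normalize_probs(already), already))  # True
log_already = already.log()
print(torch.allclose(normalize_logits(log_already), log_already, atol=1e-6))  # True

# Edge case: an all-zero row gives 0 / 0, i.e. NaN.
zeros = torch.zeros(3)
print(normalize_probs(zeros))  # tensor([nan, nan, nan])
```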

Thanks in advance for letting me know whether this would be an acceptable change, and whether there are already tickets (or similar) for it. If not, it seems like a nice, fairly small thing to contribute, unless I am missing something about typical usage.