Distribution Implementations

My team just open-sourced a distributions library that we’re hoping to move upstream into PyTorch: https://github.com/uber/pyro/blob/dev/pyro/distributions

Thanks for your work, that’s really appreciated.

Am I correct in saying that PyTorch’s Normal distribution is only for a spherical Gaussian, and that a multivariate normal has not been implemented?

You’re correct, torch.distributions.Normal is batched univariate. Brooks is working on a torch.distributions.MultivariateNormal https://github.com/probtorch/pytorch/pull/52 but this is blocked by incomplete support for gradients and batching in many of PyTorch’s linear algebra operations. We have a partial implementation of MultivariateNormal in Pyro https://github.com/uber/pyro/blob/c7c11e2/pyro/distributions/multivariate_normal.py but it does not support batched covariances and may have bugs in the gradient with respect to the Cholesky factor of the covariance matrix.
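
To make the "batched univariate" vs. multivariate distinction concrete, here is a rough pure-Python sketch. It does not use the torch.distributions or Pyro APIs; the function names normal_log_prob and mvn_log_prob and the scale_tril argument are hypothetical, chosen only for illustration. It assumes the covariance is given through a lower-triangular Cholesky factor L with a positive diagonal, so that the covariance is L Lᵀ.

```python
import math


def normal_log_prob(x, loc, scale):
    """Elementwise log-density of independent univariate normals.

    Hypothetical helper: returns one log-probability per batch element.
    """
    return [
        -0.5 * ((xi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for xi, m, s in zip(x, loc, scale)
    ]


def mvn_log_prob(x, loc, scale_tril):
    """Log-density of a single multivariate normal with covariance L @ L.T.

    Hypothetical helper: scale_tril is L, a lower-triangular list of lists
    with a positive diagonal. Returns one log-probability for the whole vector.
    """
    n = len(x)
    # Solve L z = (x - loc) by forward substitution.
    z = []
    for i in range(n):
        acc = x[i] - loc[i] - sum(scale_tril[i][j] * z[j] for j in range(i))
        z.append(acc / scale_tril[i][i])
    # log det(L) is the sum of the logs of the diagonal entries.
    half_log_det = sum(math.log(scale_tril[i][i]) for i in range(n))
    return -0.5 * sum(v * v for v in z) - half_log_det - 0.5 * n * math.log(2 * math.pi)


x = [0.5, -1.0]
loc = [0.0, 1.0]
scale = [1.0, 2.0]

# Batched univariate: two independent densities, two numbers back.
batched = normal_log_prob(x, loc, scale)

# Multivariate with a diagonal factor: one number, equal to sum(batched).
joint_diag = mvn_log_prob(x, loc, [[1.0, 0.0], [0.0, 2.0]])

# Multivariate with an off-diagonal entry: one number, no longer the sum.
joint_corr = mvn_log_prob(x, loc, [[1.0, 0.0], [0.8, 2.0]])
```

With a diagonal factor the joint log-density is just the sum of the per-element ones, which is all a batched univariate Normal can express. Once the factor has off-diagonal entries you need the triangular solve and the log-determinant, and those are the kinds of linear algebra operations that would need gradient and batching support.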