My team just open sourced a distributions library that we’re hoping to move upstream into PyTorch: https://github.com/uber/pyro/blob/dev/pyro/distributions
Thanks for your work, that’s really appreciated.
Am I correct in saying that PyTorch’s Normal distribution is only for a spherical Gaussian, and that a multivariate normal has not been implemented?
torch.distributions.Normal is batched univariate. Brooks is working on a torch.distributions.MultivariateNormal in https://github.com/probtorch/pytorch/pull/52, but this is blocked by incomplete support for gradients and batching in many of PyTorch’s linear algebra operations. We have a partial implementation of MultivariateNormal in Pyro at https://github.com/uber/pyro/blob/c7c11e2/pyro/distributions/multivariate_normal.py, but it does not support batched covariances and may have bugs in the gradient with respect to the Cholesky-factorized covariance matrix.
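To make the distinction above concrete, here is a minimal plain-Python sketch. It is not Pyro's or PyTorch's code, and the function names are hypothetical, chosen only for illustration. A batch of independent univariate normals, with log-probabilities summed over the last dimension, is exactly a diagonal Gaussian (spherical when all scales are equal). A full-covariance multivariate normal needs the Cholesky factor L of the covariance (cov = L L^T) to compute the Mahalanobis term and the log-determinant. The sketch uses only the standard library so the math is visible; in PyTorch the same steps would run on batched tensors.

```python
import math

LOG_2PI = math.log(2 * math.pi)

def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * LOG_2PI

def diagonal_normal_log_prob(x, loc, scale):
    """Sum of independent univariate log-densities, i.e. a diagonal Gaussian."""
    return sum(normal_log_prob(xi, mi, si) for xi, mi, si in zip(x, loc, scale))

def cholesky(cov):
    """Lower-triangular L with L L^T == cov, for symmetric positive-definite cov."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L

def mvn_log_prob(x, loc, cov):
    """Full-covariance multivariate normal log-density via the Cholesky factor."""
    L = cholesky(cov)
    n = len(x)
    diff = [xi - mi for xi, mi in zip(x, loc)]
    # Forward substitution: solve L z = diff, so z^T z = diff^T cov^{-1} diff.
    z = [0.0] * n
    for i in range(n):
        z[i] = (diff[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i]
    mahalanobis = sum(zi * zi for zi in z)
    # 0.5 * log det(cov) equals the sum of log L[i][i].
    half_log_det = sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * mahalanobis - half_log_det - 0.5 * n * LOG_2PI

if __name__ == "__main__":
    x, loc = [0.5, -1.0], [0.0, 0.0]
    # With a diagonal covariance both routes agree.
    print(diagonal_normal_log_prob(x, loc, [1.0, 2.0]))
    print(mvn_log_prob(x, loc, [[1.0, 0.0], [0.0, 4.0]]))
    # With correlated dimensions, only the Cholesky route applies.
    print(mvn_log_prob([1.0, 0.0], loc, [[2.0, 1.0], [1.0, 2.0]]))
```

The off-diagonal entries of the covariance are what the batched univariate Normal cannot express. Gradients and batching through the Cholesky step are the parts that made a batched MultivariateNormal harder to support at the time.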