Backward for negative log likelihood loss of MultivariateNormal (in distributions)

Normal is a batched univariate distribution. Your mu is being broadcast up to the same shape as C, producing an (n, n) batch of univariate normals. If you want a multivariate normal distribution, use MultivariateNormal instead:
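To see the broadcasting behavior concretely, here is a minimal sketch. Normal requires a strictly positive scale, so a tensor of ones stands in for C (whose off-diagonal zeros would fail validation), but the shape arithmetic is the same:

```python
import torch

n = 5
mu = torch.zeros(n)          # shape (n,)
scale = torch.ones(n, n)     # shape (n, n), stands in for C
d = torch.distributions.Normal(mu, scale)  # mu is broadcast against scale
print(d.batch_shape)         # torch.Size([5, 5]) -- an (n, n) batch of univariate normals
print(d.sample().shape)      # torch.Size([5, 5]), not (n,)
```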

import torch

n = 5
mu = torch.zeros(n)
C = torch.eye(n)
m = torch.distributions.MultivariateNormal(mu, covariance_matrix=C)
x = m.sample()  # should have shape (n,)
loss = -m.log_prob(x)  # should be a scalar
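Since the question is about the backward pass, here is a sketch of backpropagating the negative log likelihood through to mu; making mu a leaf tensor with requires_grad=True is an assumption about how you set up your parameters:

```python
import torch

n = 5
# make mu a learnable parameter so the NLL can be backpropagated to it
mu = torch.zeros(n, requires_grad=True)
C = torch.eye(n)
m = torch.distributions.MultivariateNormal(mu, covariance_matrix=C)
x = m.sample()
loss = -m.log_prob(x)   # scalar negative log likelihood
loss.backward()
print(mu.grad.shape)    # torch.Size([5]) -- one gradient entry per component of mu
```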