`Normal` is a batched univariate distribution. Your `mu` is being broadcast up to the same shape as `C`, producing an `(n, n)` batch of univariate normals. If you want a `MultivariateNormal` distribution, use:
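```python
import torch

n = 5
mu = torch.zeros(n)
C = torch.eye(n, n)  # covariance matrix (Normal's scale, by contrast, is a standard deviation)
m = torch.distributions.MultivariateNormal(mu, covariance_matrix=C)
x = m.sample()        # shape (n,)
loss = -m.log_prob(x)  # scalar: one log-density per n-dimensional sample
```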
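To see the difference, compare `batch_shape` and `event_shape` on the two distributions. The sketch below assumes the original code passed an `(n, n)` tensor as `Normal`'s second (scale) argument; the name `scale` and its all-ones value are hypothetical, chosen so every entry is a valid positive standard deviation. It also checks that, with identity covariance, the multivariate log-density equals the sum of `n` independent standard-normal log-densities.

```python
import torch

n = 5
mu = torch.zeros(n)

# Hypothetical (n, n) scale tensor passed to Normal; all ones so every entry
# is a valid (positive) standard deviation.
scale = torch.ones(n, n)
batched = torch.distributions.Normal(mu, scale)
# mu (n,) is broadcast against scale (n, n): n*n separate univariate normals.
print(batched.batch_shape, batched.event_shape)  # torch.Size([5, 5]) torch.Size([])

# MultivariateNormal treats the last dimension as a single event of size n.
C = torch.eye(n, n)
mvn = torch.distributions.MultivariateNormal(mu, covariance_matrix=C)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([]) torch.Size([5])

x = mvn.sample()
# With identity covariance the n components are independent standard normals,
# so the joint log-density is the sum of n univariate log-densities.
uni = torch.distributions.Normal(mu, torch.ones(n))
print(torch.allclose(mvn.log_prob(x), uni.log_prob(x).sum()))  # True
```

The key point is where the dimension of size `n` ends up: in `batch_shape` for `Normal` (independent distributions, one log-probability each) versus `event_shape` for `MultivariateNormal` (one joint distribution, one log-probability per sample).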