Hi Jake!
When you instantiate d1, you pass MultivariateNormal the covariance
matrix (the diagonal elements of which are variances) that defines your
distribution. Normal, in contrast, takes the standard deviations (which are
the square roots of the variances).
Try:
d1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(std**2))
and you should get the results you expect.
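Here is a minimal sketch of that check. The values of mean, std, and x are made up for illustration (your actual tensors aren't shown here), and std is deliberately not all ones so that a mismatch would be visible:

```python
import torch

# Hypothetical values, chosen only for illustration.
mean = torch.tensor([0.0, 1.0, -1.0])
std = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)
x = torch.tensor([0.3, 0.2, 1.5])

# Normal takes standard deviations; sum the per-component log-probs
# to get the joint log-prob of the independent components.
d2 = torch.distributions.Normal(mean, std)
lp_normal = d2.log_prob(x).sum()
(grad_normal,) = torch.autograd.grad(lp_normal, std)

# MultivariateNormal takes a covariance matrix, so the diagonal holds std**2.
d1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(std**2))
lp_mvn = d1.log_prob(x)
(grad_mvn,) = torch.autograd.grad(lp_mvn, std)

# Both comparisons should agree up to floating-point tolerance.
print(torch.allclose(lp_normal, lp_mvn, atol=1e-5))
print(torch.allclose(grad_normal, grad_mvn, atol=1e-5))
```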
(The power of 2 that relates the variance to the standard deviation explains
the factor of two in your two computations of the gradient. The fact that you
got the same log_prob() in both cases is an artifact of std being a vector
of 1.0s, whose square is equal to itself.)
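To see the factor of two concretely, here is a rough sketch. It assumes your original code put std itself (rather than std**2) on the diagonal of the covariance matrix. The values of mean and x are again made up:

```python
import torch

# Hypothetical values; std is all ones, as in the case discussed above.
mean = torch.zeros(3)
std = torch.ones(3, requires_grad=True)
x = torch.tensor([0.5, -1.0, 2.0])

# Reference: Normal parameterized by the standard deviation.
lp_normal = torch.distributions.Normal(mean, std).log_prob(x).sum()
(grad_normal,) = torch.autograd.grad(lp_normal, std)

# Assumed mistake: std on the diagonal is interpreted as the variance.
wrong_cov = torch.diag_embed(std)
lp_wrong = torch.distributions.MultivariateNormal(mean, wrong_cov).log_prob(x)
(grad_wrong,) = torch.autograd.grad(lp_wrong, std)

# At std == 1 the log-probs coincide, but the gradients differ by a factor of 2.
print(torch.allclose(lp_normal, lp_wrong, atol=1e-5))
print(torch.allclose(grad_normal, 2.0 * grad_wrong, atol=1e-5))
```

With any std other than all ones, the two log_prob() values would no longer agree either, which makes this kind of mix-up easier to spot.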
Best.
K. Frank