Different gradients from `torch.distributions.MultivariateNormal`/`torch.diag_embed` vs `torch.distributions.Normal`

  • I’m looking at two different ways of backpropagating through the log-probability of samples from a Gaussian random variable with respect to its parameters:
    • With torch.distributions.MultivariateNormal and torch.diag_embed
    • With torch.distributions.Normal, broadcasting and summing over the last dimension of the result
  • The calculated log-probability is the same in each case
  • The gradient of the log probability WRT the mean is the same in each case
  • The gradients of the log probability WRT the standard deviation differ (seemingly by a factor of 2), as shown below
  • This is an MRE (minimal reproducible example) for the behaviour of a PPO implementation, which gives meaningfully different results depending on how I define the standard deviation and calculate the log probability (torch.distributions.MultivariateNormal gives significantly better results, which is unfortunate because I believe it is also less efficient)
  • Is this expected behaviour or is this a bug? I would have expected both approaches to be effectively doing the same thing and therefore to give the same answer
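```python
import torch

torch.manual_seed(0)
n = 10
mean = torch.zeros([n])
std = torch.ones([n], requires_grad=True)
a = torch.normal(mean, std)  # one sample per dimension
print(a)

# Approach 1: MultivariateNormal with a diagonal matrix built from std
d1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(std))
p1 = d1.log_prob(a)
print(p1)

# Approach 2: independent Normals, log-probabilities summed over the last dimension
d2 = torch.distributions.Normal(mean, std)
p2 = d2.log_prob(a).sum(dim=-1)
print(p2)

# Compare the gradients of the two log-probabilities w.r.t. std
p1.backward()
print(std.grad)
std.grad.zero_()
p2.backward()
print(std.grad)
```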

Output:

```text
tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380,
        -0.7193, -0.4033], grad_fn=<NormalBackward3>)
tensor(-15.2935, grad_fn=<SubBackward0>)
tensor(-15.2935, grad_fn=<SumBackward1>)
tensor([ 0.6873, -0.4569,  1.8736, -0.3384,  0.0881,  0.4780, -0.4187, -0.1489,
        -0.2413, -0.4187])
tensor([ 1.3747, -0.9139,  3.7471, -0.6769,  0.1762,  0.9561, -0.8373, -0.2977,
        -0.4827, -0.8373])
```

Hi Jake!

When you instantiate d1, you pass MultivariateNormal the covariance
matrix (the diagonal elements of which are variances) that defines your
distribution. Normal, in contrast, takes the standard deviations (which are
the square roots of the variances).
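For reference, this is why the two approaches agree once the diagonal holds variances. Writing $\mu$ and $\sigma$ for the mean and standard-deviation vectors in your code, a diagonal covariance $\Sigma = \operatorname{diag}(\sigma_1^2,\dots,\sigma_n^2)$ makes the multivariate log-density split into a sum of univariate ones:

$$
\log \mathcal{N}(a;\mu,\Sigma)
= -\frac{1}{2}\sum_{i=1}^{n}\frac{(a_i-\mu_i)^2}{\sigma_i^2} - \sum_{i=1}^{n}\log\sigma_i - \frac{n}{2}\log(2\pi)
= \sum_{i=1}^{n}\log\mathcal{N}(a_i;\mu_i,\sigma_i^2)
$$

The right-hand side is what `Normal(mean, std).log_prob(a).sum(dim=-1)` computes. Passing `torch.diag_embed(std)` puts $\sigma_i$ on the diagonal where $\sigma_i^2$ belongs.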

Try:

```python
d1 = torch.distributions.MultivariateNormal (mean, torch.diag_embed (std**2))
```

and you should get the results you expect.
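A quick way to convince yourself is to repeat the comparison with standard deviations that are not all 1.0, so that `std` and `std**2` actually differ. The sketch below (sizes and random values are arbitrary choices for illustration) detaches the sample so gradients flow only through `log_prob`, and also tries the `scale_tril` argument of `MultivariateNormal`, which takes the Cholesky factor; for a diagonal covariance that factor is simply `diag(std)`.

```python
import torch

torch.manual_seed(0)
n = 10
mean = torch.zeros(n)
# Non-unit standard deviations, so std and std**2 really differ.
std = (torch.rand(n) + 0.5).requires_grad_()
a = torch.normal(mean, std).detach()  # fixed sample; no gradient through it

# Reference: independent Normals (scale = std), summed over the event dimension.
p_normal = torch.distributions.Normal(mean, std).log_prob(a).sum(dim=-1)
(g_normal,) = torch.autograd.grad(p_normal, std)

# Fixed version: the covariance diagonal holds variances.
p_cov = torch.distributions.MultivariateNormal(
    mean, covariance_matrix=torch.diag_embed(std**2)
).log_prob(a)
(g_cov,) = torch.autograd.grad(p_cov, std)

# Alternative: pass the Cholesky factor, which is diag(std) here.
p_tril = torch.distributions.MultivariateNormal(
    mean, scale_tril=torch.diag_embed(std)
).log_prob(a)
(g_tril,) = torch.autograd.grad(p_tril, std)

# Original version: std on the diagonal is treated as a variance.
p_bug = torch.distributions.MultivariateNormal(
    mean, covariance_matrix=torch.diag_embed(std)
).log_prob(a)

print(torch.allclose(p_normal, p_cov, atol=1e-5), torch.allclose(g_normal, g_cov, atol=1e-5))
print(torch.allclose(p_normal, p_tril, atol=1e-5), torch.allclose(g_normal, g_tril, atol=1e-5))
print(torch.allclose(p_normal, p_bug, atol=1e-5))
```

With non-unit values the original `diag_embed(std)` version no longer matches even in `log_prob`, while both corrected versions do.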

(The power of 2 that relates the variance to the standard deviation explains
the factor of two in your two computations of the gradient. The fact that you
got the same log_prob() in both cases is an artifact of std being a vector
of 1.0s whose square is equal to itself.)
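Writing out the two derivatives makes the factor of two explicit. Let $v_i$ be the diagonal entries passed to `MultivariateNormal` (interpreted as variances) and $\sigma_i$ the values passed to `Normal` (interpreted as standard deviations):

$$
\frac{\partial}{\partial v_i}\log\mathcal{N}\big(a;\mu,\operatorname{diag}(v)\big) = \frac{(a_i-\mu_i)^2}{2v_i^2}-\frac{1}{2v_i},
\qquad
\frac{\partial}{\partial \sigma_i}\log\mathcal{N}(a_i;\mu_i,\sigma_i^2) = \frac{(a_i-\mu_i)^2}{\sigma_i^3}-\frac{1}{\sigma_i}
$$

With $v_i = \sigma_i^2$ the chain rule gives $\partial/\partial\sigma_i = 2\sigma_i\,\partial/\partial v_i$, so the ratio is exactly 2 only because $\sigma_i = 1$ in the test. For the first component ($a_1 = 1.5410$, $\mu_1 = 0$) the two expressions give $\tfrac{1}{2}(1.5410^2-1)\approx 0.6873$ and $1.5410^2-1\approx 1.3747$, matching the printed gradients.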

Best.

K. Frank


Hi @KFrank ,

Thanks for explaining, makes 100% perfect sense 🙂

This works. Equivalently, I can apply torch.sqrt to std (which should really be renamed to, e.g., action_var) before passing it to torch.distributions.Normal, keep everything else the same, and get identical results, even when sampling from the two distributions (as long as I reset the random seed before each sample):

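```python
import torch

torch.manual_seed(0)
n = 10
mean = torch.zeros([n])
# Despite its name, std now plays the role of the variance (e.g. action_var)
std = torch.ones([n], requires_grad=True)
a = torch.normal(mean, std)
print(a)

# MultivariateNormal: variances on the diagonal of the covariance matrix
d1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(std))
p1 = d1.log_prob(a)
print(p1)

# Normal: standard deviations, i.e. square roots of the variances
d2 = torch.distributions.Normal(mean, torch.sqrt(std))
p2 = d2.log_prob(a).sum(dim=-1)
print(p2)

p1.backward()
print(std.grad)
std.grad.zero_()
p2.backward()
print(std.grad)

# Reset the seed before each draw so both distributions use the same noise
torch.manual_seed(0)
print(d1.sample())
torch.manual_seed(0)
print(d2.sample())
```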

Output:

```text
tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380,
        -0.7193, -0.4033], grad_fn=<NormalBackward3>)
tensor(-15.2935, grad_fn=<SubBackward0>)
tensor(-15.2935, grad_fn=<SumBackward1>)
tensor([ 0.6873, -0.4569,  1.8736, -0.3384,  0.0881,  0.4780, -0.4187, -0.1489,
        -0.2413, -0.4187])
tensor([ 0.6873, -0.4569,  1.8736, -0.3384,  0.0881,  0.4780, -0.4187, -0.1489,
        -0.2413, -0.4187])
tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380,
        -0.7193, -0.4033])
tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845, -1.3986,  0.4033,  0.8380,
        -0.7193, -0.4033])
```

Ideally I’d like to use torch.distributions.Normal with sqrt and log_prob(...).sum(dim=-1) instead of torch.distributions.MultivariateNormal with diag_embed in my PPO implementation, because the former should be more efficient for high-dimensional actions. The demo above suggests that both implementations should give the same results, yet oddly I still get different results from the two implementations in PPO (with the same random seed in each experiment). I don’t understand why, although I guess it’s a bit off-topic from my original question.
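For the efficiency goal, a sketch of one idiomatic option: `torch.distributions.Independent` wraps a batched `Normal` and reinterprets its last dimension as the event dimension, so `log_prob` and `entropy` are summed over the action dimension automatically, without building an n×n covariance matrix. The batch size, action dimension, and random parameters below are made up for illustration. If your PPO loss also includes an entropy bonus, it may be worth checking that the Normal-based version reduces the entropy over the action dimension as well; that is only a guess at one possible source of discrepancy, not something established here.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
batch, n = 4, 6  # hypothetical: 4 states, 6-dimensional actions
mean = torch.randn(batch, n)
std = torch.rand(batch, n) + 0.5
actions = torch.normal(mean, std)

# Diagonal Gaussian: the last dimension of Normal becomes the event dimension.
diag = Independent(Normal(mean, std), 1)
# Same distribution via MultivariateNormal, using the Cholesky factor diag(std).
full = MultivariateNormal(mean, scale_tril=torch.diag_embed(std))

print(diag.batch_shape, diag.event_shape)
print(torch.allclose(diag.log_prob(actions), full.log_prob(actions), atol=1e-5))
print(torch.allclose(diag.entropy(), full.entropy(), atol=1e-5))
```

Both objects describe the same distribution; the `Independent` version only ever stores and operates on the n standard deviations per state.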

Hi Jake!

This is most likely not a problem.

It would not be surprising for the two implementations to give results that
differ by numerical round-off error (even if they are mathematically equivalent).

Round-off error can accumulate, so, depending on how complicated a single
iteration is, you might see the results of the two methods differ by, say, ten
times typical round-off error after a single iteration. As you continue to iterate,
the two computations may well start to diverge. This may or may not be a
problem; if it is, the culprit isn’t the round-off error itself, but rather that
your computation is unstable in some way, so that small differences get amplified.
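As a toy illustration of that amplification (a chaotic logistic map, not anything from PPO; the function name, starting value, and step counts are arbitrary), the sketch below runs the same iteration in float32 and in float64. The two runs start out differing only by round-off, and the gap keeps growing until the trajectories are essentially unrelated; a stable iteration would typically keep the gap near round-off level.

```python
import torch


def logistic_trajectory(x0: float, dtype: torch.dtype, steps: int = 60) -> torch.Tensor:
    """Iterate the chaotic logistic map x -> 4x(1 - x) in the given precision."""
    x = torch.tensor(x0, dtype=dtype)
    values = []
    for _ in range(steps):
        values.append(x.item())
        x = 4 * x * (1 - x)
    return torch.tensor(values, dtype=torch.float64)


t32 = logistic_trajectory(0.1, torch.float32)
t64 = logistic_trajectory(0.1, torch.float64)
gap = (t32 - t64).abs()
for step in (0, 10, 20, 30, 40, 50):
    print(f"step {step:2d}: |float32 - float64| = {gap[step].item():.1e}")
```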

If your results differ by a lot more than typical round-off error after a single
iteration, you likely have a real problem that you should track down.

A good way to test whether your discrepancy is due to round-off error is to
repeat your computation in double-precision and check that the discrepancy
drops by several orders of magnitude.
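A minimal sketch of that double-precision check, applied to the two log-probability formulations from this thread rather than to a full PPO loop (the function name `compare`, the size, and the random inputs are made up for illustration). It reports the largest disagreement between the formulations, in `log_prob` and in its gradient, at each precision; if the disagreement is round-off, the float64 numbers should be many orders of magnitude smaller than the float32 ones.

```python
import torch


def compare(dtype: torch.dtype, n: int = 200, seed: int = 0):
    """Max |difference| between the MultivariateNormal and Normal formulations,
    in log_prob and in its gradient w.r.t. the variances, at the given dtype."""
    gen = torch.Generator().manual_seed(seed)
    mean = torch.randn(n, generator=gen, dtype=dtype)
    var = (torch.rand(n, generator=gen, dtype=dtype) + 0.5).requires_grad_()
    # Sample with the right variances; detached so gradients flow only via log_prob.
    a = mean + var.detach().sqrt() * torch.randn(n, generator=gen, dtype=dtype)

    p1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(var)).log_prob(a)
    p2 = torch.distributions.Normal(mean, var.sqrt()).log_prob(a).sum(dim=-1)
    (g1,) = torch.autograd.grad(p1, var)
    (g2,) = torch.autograd.grad(p2, var)
    return (p1 - p2).abs().item(), (g1 - g2).abs().max().item()


for dtype in (torch.float32, torch.float64):
    dp, dg = compare(dtype)
    print(f"{dtype}: max |diff| log_prob = {dp:.1e}, grad = {dg:.1e}")
```

In the actual PPO code the analogous experiment would be to run both variants in float64 (for instance by calling `.double()` on the model and its input tensors) and see whether the gap between them shrinks accordingly.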

Best.

K. Frank
