- I’m looking at two different ways of backpropagating through the log-probability of samples from a Gaussian RV WRT its parameters:
- With `torch.distributions.MultivariateNormal` and `torch.diag_embed`
- With `torch.distributions.Normal`, broadcasting and summing over the last dimension of the result
- The calculated log-probability is the same in each case
- The gradient of the log probability WRT the mean is the same in each case
- The gradient of the log probability WRT the standard deviation differs between the two cases (seemingly by a factor of 2), as shown below
- This is an MRE for the behaviour of a PPO implementation, which gives meaningfully different results according to how I define the standard deviation and calculate the log probability (`torch.distributions.MultivariateNormal` gives significantly better results, which is unfortunate because I believe it is also less efficient)
- Is this expected behaviour or is this a bug? I would have expected both approaches to be effectively doing the same thing and therefore to give the same answer
import torch

torch.manual_seed(0)

n = 10
mean = torch.zeros([n])
std = torch.ones([n], requires_grad=True)
a = torch.normal(mean, std)
print(a)

# Approach 1: MultivariateNormal with a diagonal matrix built from std
d1 = torch.distributions.MultivariateNormal(mean, torch.diag_embed(std))
p1 = d1.log_prob(a)
print(p1)

# Approach 2: independent Normals, summing log-probs over the last dimension
d2 = torch.distributions.Normal(mean, std)
p2 = d2.log_prob(a).sum(dim=-1)
print(p2)

p1.backward()
print(std.grad)

std.grad.zero_()
p2.backward()
print(std.grad)
Output:
tensor([ 1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986, 0.4033, 0.8380,
-0.7193, -0.4033], grad_fn=<NormalBackward3>)
tensor(-15.2935, grad_fn=<SubBackward0>)
tensor(-15.2935, grad_fn=<SumBackward1>)
tensor([ 0.6873, -0.4569, 1.8736, -0.3384, 0.0881, 0.4780, -0.4187, -0.1489,
-0.2413, -0.4187])
tensor([ 1.3747, -0.9139, 3.7471, -0.6769, 0.1762, 0.9561, -0.8373, -0.2977,
-0.4827, -0.8373])
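To isolate the discrepancy outside the PPO setting, here is a minimal single-sample sketch (mine, not from the original MRE) comparing the two gradients via `torch.autograd.grad`. At `std = 1` the `Normal` gradient WRT the scale is analytically `x**2 - 1` per dimension, and the `MultivariateNormal` gradient through `diag_embed(std)` comes out as exactly half of it:

```python
import torch

torch.manual_seed(0)
x = torch.normal(torch.zeros(1), torch.ones(1))  # one sample
std = torch.ones(1, requires_grad=True)

# Gradient WRT std through Normal, which treats std as the standard deviation:
# d/dsigma log N(x; 0, sigma) = x^2/sigma^3 - 1/sigma = x^2 - 1 at sigma = 1
p_normal = torch.distributions.Normal(torch.zeros(1), std).log_prob(x).sum()
(g_normal,) = torch.autograd.grad(p_normal, std)

# Gradient WRT std through MultivariateNormal, passing diag_embed(std) as the
# (covariance) matrix argument
p_mvn = torch.distributions.MultivariateNormal(
    torch.zeros(1), torch.diag_embed(std)
).log_prob(x)
(g_mvn,) = torch.autograd.grad(p_mvn, std)

print(g_normal.item(), g_mvn.item())
print(torch.allclose(g_normal, 2 * g_mvn))  # True at std = 1
```

The factor of 2 is exact here, not approximate, which suggests the two calls are differentiating through different parameterisations of the same density rather than accumulating numerical error.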