Strange behavior of log_prob when calculating gradient

I have the following model:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

class Policy(nn.Module):
    def __init__(self, num_inp, num_out, hidden_dim):
        super().__init__()
        
        self.linear1 = nn.Linear(num_inp, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        
        self.mean_linear = nn.Linear(hidden_dim, num_out)
        self.log_std_linear = nn.Linear(hidden_dim, num_out)

    def forward(self, obs):
        x = F.relu(self.linear1(obs))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = torch.clamp(self.log_std_linear(x), min=-2, max=20)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()               # reparameterized sample
        log_prob = normal.log_prob(x_t)      # log-density of the sample just drawn
        log_prob = log_prob.sum()
        return x_t, log_prob

The gradient with respect to the parameters can then be calculated as:

import torch.autograd as autograd

model = Policy(17, 6, 50)
inp = torch.rand(1, 17)
x_t, log_prob = model(inp)

grads = autograd.grad(log_prob, model.parameters())

However, the gradient with respect to mean_linear (weight and bias) is 0; all other gradients seem correct (linear1, linear2, log_std_linear). I have absolutely no idea why this is the case. When checking the source code of `log_prob` I can see it uses both `self.loc` (the mean) and `self.scale` (the std), so the gradient should propagate, but somehow the gradient for the mean is always `0`. Why does this happen, and how can I fix it?
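
To see which parameters are affected, I pair each gradient with its parameter name (reusing model and grads from above):

for (name, _), grad in zip(model.named_parameters(), grads):
    print(name, "all zero:", bool(torch.all(grad == 0)))   # the mean_linear entries come out all zero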

Hi cijerezg!

When you call .log_prob(x_t) using the same Normal distribution from
which you .rsample()d x_t, the mean drops out, so you are correctly
getting zero for your gradient.

What’s going on is that when you use “reparameterization sampling,” the
Distribution generates a base random variate (for Normal, a standard
normal variate) and transforms it into a variate from your desired
distribution. This transformation is differentiable with respect to the
parameters that describe your distribution, so you can backpropagate
through the rsample()d value.
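
For Normal, that transformation looks (roughly) like this under the hood:

eps = torch.randn_like(mean)     # standard-normal noise, independent of the parameters
x_t = mean + std * eps           # differentiable with respect to mean and std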

In the case of Normal, that random variate tells you how improbable
your sample should be – how many standard deviations it should be
from the mean of your Normal distribution. log_prob() then tells you
how improbable your sample is – that is, how many standard deviations
it is from that same mean. So log_prob() is independent of mean, and
your gradient is zero.
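
Written out, the Normal log-density is
log p(x) = -log(std) - 0.5 * log(2 * pi) - (x - mean)**2 / (2 * std**2),
and substituting the reparameterized sample x = mean + std * eps turns the last
term into eps**2 / 2. The mean cancels completely (while std survives through
the -log(std) term), so the gradient of log_prob with respect to the mean is
exactly zero.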

Consider:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> normal10 = torch.distributions.Normal (10.0, 1.234)
>>> normal100 = torch.distributions.Normal (100.0, 1.234)
>>> _ = torch.manual_seed (2022)
>>> samp10 = normal10.rsample ([5])
>>> _ = torch.manual_seed (2022)
>>> samp100 = normal100.rsample ([5])
>>> samp10
tensor([10.2363, 10.4080, 10.2846, 11.1027,  9.7478])
>>> samp100
tensor([100.2363, 100.4080, 100.2846, 101.1027,  99.7478])
>>> normal10.log_prob (samp10)
tensor([-1.1475, -1.1839, -1.1558, -1.5285, -1.1501])
>>> normal100.log_prob (samp100)
tensor([-1.1475, -1.1839, -1.1558, -1.5285, -1.1501])
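
You can also check the gradient directly:

>>> loc = torch.tensor (10.0, requires_grad = True)
>>> scale = torch.tensor (1.234, requires_grad = True)
>>> normal = torch.distributions.Normal (loc, scale)
>>> lp = normal.log_prob (normal.rsample ())
>>> torch.autograd.grad (lp, loc)
(tensor(0.),)

(The gradient with respect to scale, by contrast, is nonzero; it works out to -1 / scale.)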

Best.

K. Frank

Thanks for the insightful answer. This gave me another question, in case you know. When using RL, one differentiates the log of the policy with respect to the parameters, e.g., the mean and the variance. Arguably, having the correct mean is more important than having the correct variance; for example, some RL algorithms assume a fixed variance and learn only the mean.

How can one go about getting the gradient of the mean? Does it have to be done manually?
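
In case it helps, one common pattern is the score-function (REINFORCE-style) gradient: evaluate log_prob on an action that is treated as fixed data, e.g. by detaching the sample, so the gradient flows through the density's dependence on mean and std rather than through the sample itself. A minimal sketch, changing only the last lines of the forward() above (just an illustration, not the only way to do it):

normal = Normal(mean, std)
x_t = normal.rsample()
log_prob = normal.log_prob(x_t.detach()).sum()   # detach: mean now gets a nonzero gradient
return x_t, log_prob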