Strange behavior of log_prob when calculating gradient

Hi cijerezg!

When you call .log_prob (x_t) using the same Normal distribution from
which you .rsample()ed x_t, the mean drops out, so you are correctly
getting zero for your gradient.
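Here is a minimal sketch that reproduces this. It assumes the mean and the standard deviation are both scalar tensors with requires_grad=True; the names loc, scale, and x_t are only illustrative:

```python
import torch

torch.manual_seed(2022)

# Illustrative trainable parameters: a scalar mean and standard deviation.
loc = torch.tensor(10.0, requires_grad=True)
scale = torch.tensor(1.234, requires_grad=True)

dist = torch.distributions.Normal(loc, scale)

# Reparameterized sample: x_t depends differentiably on loc and scale.
x_t = dist.rsample([5])

# Score x_t under the same distribution it was drawn from, then backpropagate.
loss = dist.log_prob(x_t).sum()
loss.backward()

print(loc.grad)    # (numerically) zero: the mean drops out
print(scale.grad)  # nonzero: each sample contributes about -1 / scale
```

Only the mean drops out. The gradient with respect to scale comes out (up to round-off) as -5 / 1.234, because log_prob() of the rsample()ed value still depends on the scale.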

What’s going on is that when you use “reparameterization sampling,” the
Distribution generates some random variate that does not depend on the
distribution's parameters (for Normal, a standard-normal variate, eps)
and transforms it into a variate from your desired distribution (for
Normal, loc + eps * scale). This transformation is differentiable with
respect to the parameters that describe your distribution, so you can
backpropagate through the rsample()ed value.
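For Normal, the transformation is just a shift and a scale of a standard-normal draw. The sketch below rolls it by hand; manual_rsample is a hypothetical helper written for illustration, not a PyTorch function, although as far as I know Normal.rsample() computes the same loc + eps * scale internally:

```python
import torch

def manual_rsample(loc, scale, n):
    # Hypothetical helper: draw eps ~ N(0, 1), which does not depend on loc or scale.
    eps = torch.randn(n)
    # Shift and scale eps; this step is differentiable in loc and scale.
    return loc + scale * eps, eps

torch.manual_seed(0)
loc = torch.tensor(10.0, requires_grad=True)
scale = torch.tensor(1.234, requires_grad=True)

x, eps = manual_rsample(loc, scale, 5)
x.sum().backward()

# d x_i / d loc = 1 for every sample, so loc.grad equals the sample count.
print(loc.grad)
# d x_i / d scale = eps_i, so scale.grad equals the sum of the eps values.
print(scale.grad, eps.sum())
```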

In the case of Normal, that random variate tells you how improbable
your sample should be – how many standard deviations it should be
from the mean of your Normal distribution. log_prob() then tells you
how improbable your sample is – that is, how many standard deviations
it is from that same mean. So log_prob() of the rsample()ed value is
independent of the mean (although it still depends on the scale), and
your gradient with respect to the mean is zero.
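The same argument as a short derivation, writing μ for the mean, σ for the standard deviation, and ε for the standard-normal variate, so that the rsample()ed value is x = μ + σε:

$$
\begin{aligned}
\log p(x;\mu,\sigma) &= -\frac{(x-\mu)^2}{2\sigma^2} - \log\sigma - \tfrac{1}{2}\log(2\pi) \\
\log p(\mu+\sigma\epsilon;\mu,\sigma) &= -\frac{\epsilon^2}{2} - \log\sigma - \tfrac{1}{2}\log(2\pi) \\
\frac{\partial}{\partial\mu}\log p(\mu+\sigma\epsilon;\mu,\sigma) &= 0,
\qquad
\frac{\partial}{\partial\sigma}\log p(\mu+\sigma\epsilon;\mu,\sigma) = -\frac{1}{\sigma}
\end{aligned}
$$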

Consider:

>>> import torch
>>> torch.__version__
'1.12.0'
>>> normal10 = torch.distributions.Normal (10.0, 1.234)
>>> normal100 = torch.distributions.Normal (100.0, 1.234)
>>> _ = torch.manual_seed (2022)
>>> samp10 = normal10.rsample ([5])
>>> _ = torch.manual_seed (2022)
>>> samp100 = normal100.rsample ([5])
>>> samp10
tensor([10.2363, 10.4080, 10.2846, 11.1027,  9.7478])
>>> samp100
tensor([100.2363, 100.4080, 100.2846, 101.1027,  99.7478])
>>> normal10.log_prob (samp10)
tensor([-1.1475, -1.1839, -1.1558, -1.5285, -1.1501])
>>> normal100.log_prob (samp100)
tensor([-1.1475, -1.1839, -1.1558, -1.5285, -1.1501])
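Another way to see the cancellation is to split the gradient with respect to the mean into its two paths: the explicit dependence of log_prob() on the mean, with the sample held fixed via .detach(), and the dependence that flows through the rsample()ed value itself. A sketch, with variable names chosen only for illustration:

```python
import torch

torch.manual_seed(2022)
loc = torch.tensor(10.0, requires_grad=True)
dist = torch.distributions.Normal(loc, 1.234)
x_t = dist.rsample([5])

# Path 1: treat the sample as a fixed number, so only log_prob's explicit
# dependence on the mean contributes: sum((x - loc) / scale**2).
(g_explicit,) = torch.autograd.grad(
    dist.log_prob(x_t.detach()).sum(), loc, retain_graph=True
)

# Path 2: hold the mean inside log_prob fixed and differentiate only through x_t.
fixed = torch.distributions.Normal(loc.detach(), 1.234)
(g_through_sample,) = torch.autograd.grad(fixed.log_prob(x_t).sum(), loc)

print(g_explicit, g_through_sample)   # equal and opposite
print(g_explicit + g_through_sample)  # (numerically) zero
```

Path 1 on its own is what you would get by scoring a sample that does not depend on the mean, such as one drawn with .sample(); it is the sample's own dependence on the mean (path 2) that cancels it.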

Best.

K. Frank