I just discovered the truncated normal distribution and I’m trying to use torchrl’s TruncatedNormal, but I can’t find anything about “location” in the paper that introduced this idea, and I don’t understand torchrl’s documentation:
loc (torch.Tensor) – normal distribution location parameter
Can someone explain in more detail what that parameter is? Is it like the mean parameter of torch.normal?
Yes, that’s my understanding: the location parameter defines where the distribution is positioned, and for a normal distribution the location corresponds to the mean. CC @KFrank to correct me, as I’m sure you can define and explain it properly.
Yes, this is correct. loc is the same as mean (and scale is std). For the Normal
distribution (but not other distributions), it’s just a different name for the same thing.
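A quick way to check this with plain torch.distributions (a small sketch; it uses torch.distributions.Normal rather than torchrl’s class, and the numbers 0.5 and 2.0 are arbitrary):

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)  # reproducible sampling

# Underlying Normal with location 0.5 and scale 2.0
loc = torch.tensor(0.5)
scale = torch.tensor(2.0)
dist = Normal(loc, scale)

# For a plain Normal, the mean is loc and the standard deviation is scale
print(dist.mean.item(), dist.stddev.item())  # 0.5 2.0

# torch.normal() names the same two quantities mean and std
samples = torch.normal(mean=0.5, std=2.0, size=(100_000,))
print(samples.mean().item(), samples.std().item())  # close to 0.5 and 2.0
```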
It makes sense to use a different name because, for example, the actual mean of an
(asymmetrically-truncated) TruncatedNormal distribution will not be loc – that is
the mean of the underlying Normal distribution – and the actual standard deviation
will not be scale – the std of the underlying Normal distribution.
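To see how far apart they can be, here is a sketch that computes the actual mean and std of a Normal(loc, scale) truncated to [low, high] using the textbook closed-form expressions, and checks them by rejection sampling. It does not call torchrl; truncated_normal_moments is a hypothetical helper, and torchrl’s TruncatedNormal may apply its own rescaling of loc/scale, so treat this as an illustration of the underlying math only:

```python
import torch
from torch.distributions import Normal

def truncated_normal_moments(loc, scale, low, high):
    """Mean and std of Normal(loc, scale) restricted to [low, high] (closed form)."""
    std_normal = Normal(torch.tensor(0.0), torch.tensor(1.0))
    alpha = (low - loc) / scale  # standardized lower bound
    beta = (high - loc) / scale  # standardized upper bound
    pdf_a = std_normal.log_prob(alpha).exp()
    pdf_b = std_normal.log_prob(beta).exp()
    mass = std_normal.cdf(beta) - std_normal.cdf(alpha)  # probability kept by truncation
    ratio = (pdf_a - pdf_b) / mass
    mean = loc + scale * ratio
    var = scale**2 * (1 + (alpha * pdf_a - beta * pdf_b) / mass - ratio**2)
    return mean, var.sqrt()

# Example values: loc sits close to the upper bound, so the truncation is asymmetric
loc, scale = torch.tensor(0.8), torch.tensor(1.0)
low, high = torch.tensor(-1.0), torch.tensor(1.0)
mean, std = truncated_normal_moments(loc, scale, low, high)

# Sanity check: sample the underlying Normal and keep only values inside [low, high]
torch.manual_seed(0)
raw = Normal(loc, scale).sample((1_000_000,))
kept = raw[(raw >= low) & (raw <= high)]
print(f"loc={loc.item():.3f}  mean={mean.item():.3f}  empirical={kept.mean().item():.3f}")
print(f"scale={scale.item():.3f}  std={std.item():.3f}  empirical={kept.std().item():.3f}")
```

With these values the actual mean ends up well below loc, and the actual std is smaller than scale.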
torchrl.modules.TruncatedNormal follows torch.distributions.normal.Normal
for which:
loc (float or Tensor) – mean of the distribution (often referred to as mu)
(I believe that torch.distributions.Normal uses loc and scale, rather than mean
and std, because there are other distributions in torch.distributions for which the
conventional loc and scale parameters aren’t equal to the mean and standard deviation.
On the other hand, torch.normal() calls the two parameters mean and std.)
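For example, torch.distributions.Laplace and torch.distributions.Gumbel take the same loc and scale arguments, but their mean and standard deviation are different functions of them. A short sketch comparing the library’s mean/stddev properties with the standard closed-form moments (the loc/scale values are arbitrary):

```python
import math
import torch
from torch.distributions import Gumbel, Laplace, Normal

loc, scale = torch.tensor(1.0), torch.tensor(2.0)
euler_gamma = 0.5772156649015329  # Euler-Mascheroni constant

# Standard closed-form (mean, std) for each family, written in terms of loc and scale
expected = {
    "Normal": (loc.item(), scale.item()),
    "Laplace": (loc.item(), math.sqrt(2) * scale.item()),
    "Gumbel": (loc.item() + euler_gamma * scale.item(), math.pi / math.sqrt(6) * scale.item()),
}

dists = {"Normal": Normal(loc, scale), "Laplace": Laplace(loc, scale), "Gumbel": Gumbel(loc, scale)}
for name, dist in dists.items():
    exp_mean, exp_std = expected[name]
    print(f"{name:8s} mean={dist.mean.item():.3f} (expected {exp_mean:.3f})  "
          f"std={dist.stddev.item():.3f} (expected {exp_std:.3f})")
```

Only for Normal do mean and std coincide with loc and scale.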
I can confirm: we use loc because the mean is the “first moment” of a distribution, not a parameter of that distribution per se. It just happens that for the normal distribution E[X] = mu, but that’s not true in general. Calling it mean in the Truncated / Tanh versions of the Normal distribution wouldn’t be correct, as the mean may not even have a closed-form formula.
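As an illustration of the Tanh case, here is a sketch that builds a tanh-squashed Normal with plain torch.distributions (TransformedDistribution plus TanhTransform, not torchrl’s own Tanh distribution) and estimates its mean by Monte Carlo. The estimate comes out clearly different from both loc and tanh(loc); the loc/scale values are arbitrary:

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)  # reproducible sampling
loc, scale = torch.tensor(1.0), torch.tensor(1.0)

# Y = tanh(X) with X ~ Normal(loc, scale)
tanh_normal = TransformedDistribution(Normal(loc, scale), [TanhTransform()])

# No simple formula for E[Y] here, so estimate it from samples
samples = tanh_normal.sample((200_000,))
mc_mean = samples.mean()
print(f"loc = {loc.item():.4f}, tanh(loc) = {torch.tanh(loc).item():.4f}, "
      f"Monte Carlo E[Y] = {mc_mean.item():.4f}")
```

For loc > 0, averaging tanh over the noise pulls the mean below tanh(loc), so naming this parameter mean would be misleading.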