A Numerical Difference in log probability

I see a small numerical difference when finding the log probability with a Categorical distribution, but I can’t see why.

First, find the logits and sample an action:

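import torch

# policy and state come from the surrounding RL code
logits = policy(state).log_softmax(-1)  # log probabilities, passed to Categorical as logits
m = torch.distributions.Categorical(logits=logits)
action = m.sample()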

Now get the log probability of that action in two ways:

logProb1 = logits.squeeze()[action]  # index directly into the log-softmax output
logProb2 = m.log_prob(action)        # ask the distribution for the log probability

The two values agree to about six decimal places and then differ slightly.
Shouldn't they be exactly the same?

I’m not seeing the reason for the difference.
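A self-contained sketch of the comparison described above, for anyone who wants to reproduce it without the original environment. The `policy` network (a `torch.nn.Linear(4, 3)`) and the random `state` are hypothetical stand-ins, not the network from the post.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the poster's policy network and state
policy = torch.nn.Linear(4, 3)
state = torch.randn(1, 4)

# log_softmax output, passed to Categorical as "logits"
logits = policy(state).log_softmax(-1)
m = torch.distributions.Categorical(logits=logits)
action = m.sample()  # shape (1,), because logits has shape (1, 3)

# Method 1: index directly into the log-softmax output
log_prob_index = logits.squeeze()[action]
# Method 2: ask the distribution for the log probability
log_prob_dist = m.log_prob(action)

print(log_prob_index.item(), log_prob_dist.item())
print("difference:", (log_prob_index - log_prob_dist).abs().item())
```

Any difference printed here should be at the level of float32 round-off.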

Hi Peter!

No. In general, the two results can differ by floating-point round-off error,
so this is expected.
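A short derivation of where the round-off can enter, based on the fact that PyTorch's `Categorical` normalizes its `logits` argument when it is constructed. Write $z_i$ for the entries passed as `logits` and $\ell_i$ for the log probabilities the distribution stores:

$$\ell_i = z_i - \log \sum_j e^{z_j}$$

Here $z$ is already `log_softmax` output, so $\sum_j e^{z_j} = 1$ and $\ell_i = z_i$ in exact arithmetic. In floating point, the computed $\delta = \log \sum_j e^{z_j}$ is only approximately zero (roughly machine epsilon, about $10^{-7}$ for float32), so `m.log_prob(action)` returns $z_a - \delta$ rather than exactly $z_a$.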

Try redoing your test using all torch.float64 tensors (double precision).
You should still see a discrepancy, but now in something like the 15th
significant (decimal) digit.
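A sketch of the suggested double-precision test. It uses random logits in place of a real policy output (an assumption for illustration) and reports the largest discrepancy between the indexed log-softmax value and `Categorical.log_prob` over a batch of sampled actions, once in float32 and once in float64.

```python
import torch

def max_discrepancy(dtype, n_rows=1000, n_actions=10):
    """Largest |indexed log-softmax - Categorical.log_prob| over sampled actions."""
    torch.manual_seed(0)  # same underlying random values for every dtype
    raw = torch.randn(n_rows, n_actions, dtype=torch.float64).to(dtype)
    log_probs = raw.log_softmax(-1)  # already-normalized log probabilities
    m = torch.distributions.Categorical(logits=log_probs)
    actions = m.sample()  # one action per row, shape (n_rows,)
    # Method 1: index directly into the log-softmax output
    indexed = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    # Method 2: Categorical.log_prob
    return (indexed - m.log_prob(actions)).abs().max().item()

d32 = max_discrepancy(torch.float32)
d64 = max_discrepancy(torch.float64)
print(f"float32: {d32:.3e}   float64: {d64:.3e}")
```

The float64 discrepancy, if any, should be many orders of magnitude smaller than the float32 one.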

Best.

K. Frank

Thanks for your answer!

But why am I wrong to assume that, if I pass logits to torch.distributions.Categorical, no further computation is needed to extract the log probability?

I can see that if I passed in the output of a softmax, the distribution class would need to compute a log(), but in this case that has already been done.

Hi Peter!

If I understand your question correctly, I think that you’re missing the
fact that:

torch.distributions.Categorical (logits = logits)

and:

torch.distributions.Categorical (probs = torch.softmax (logits, dim = -1))

construct essentially the same distribution (up to floating-point round-off),
as illustrated by this example:

>>> torch.__version__
'1.7.1'
>>> logits = torch.randn (5)
>>> probs = torch.softmax (logits, dim = 0)
>>> dist_logits = torch.distributions.Categorical (logits = logits)
>>> dist_probs = torch.distributions.Categorical (probs = probs)
>>> dist_logits
Categorical(logits: torch.Size([5]))
>>> dist_logits.logits
tensor([-0.5839, -2.5019, -1.7824, -2.2692, -2.4226])
>>> dist_logits.probs
tensor([0.5577, 0.0819, 0.1682, 0.1034, 0.0887])
>>> dist_probs
Categorical(probs: torch.Size([5]))
>>> dist_probs.logits
tensor([-0.5839, -2.5019, -1.7824, -2.2692, -2.4226])
>>> dist_probs.probs
tensor([0.5577, 0.0819, 0.1682, 0.1034, 0.0887])
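In PyTorch's implementation, `Categorical(logits=...)` subtracts the `logsumexp` of its input, so even `log_softmax` output is renormalized once more. The sketch below, using random logits for illustration, checks this directly and shows that the renormalization of already-normalized input changes values by at most round-off.

```python
import torch

torch.manual_seed(0)
raw = torch.randn(5)

# Categorical accepts unnormalized logits and normalizes them internally
dist_raw = torch.distributions.Categorical(logits=raw)
print(torch.allclose(dist_raw.logits, raw - raw.logsumexp(-1, keepdim=True)))  # True

# log_softmax output goes through the same normalization step
log_probs = raw.log_softmax(-1)
print("logsumexp of log_softmax output:", log_probs.logsumexp(-1).item())  # close to 0, not necessarily exactly 0
dist_lp = torch.distributions.Categorical(logits=log_probs)
print("max change from renormalizing:", (dist_lp.logits - log_probs).abs().max().item())
```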

Best.

K. Frank


Ah yes, thanks so much for clarifying that.

Best Regards,
– Peter