Hi Peter!
If I understand your question correctly, I think that you’re missing the
fact that:
torch.distributions.Categorical(logits=logits)
and:
torch.distributions.Categorical(probs=torch.softmax(logits, dim=-1))
construct essentially the same distribution (equal up to floating-point
round-off), as illustrated by this example:
>>> torch.__version__
'1.7.1'
>>> logits = torch.randn(5)
>>> probs = torch.softmax(logits, dim=0)
>>> dist_logits = torch.distributions.Categorical(logits=logits)
>>> dist_probs = torch.distributions.Categorical(probs=probs)
>>> dist_logits
Categorical(logits: torch.Size([5]))
>>> dist_logits.logits
tensor([-0.5839, -2.5019, -1.7824, -2.2692, -2.4226])
>>> dist_logits.probs
tensor([0.5577, 0.0819, 0.1682, 0.1034, 0.0887])
>>> dist_probs
Categorical(probs: torch.Size([5]))
>>> dist_probs.logits
tensor([-0.5839, -2.5019, -1.7824, -2.2692, -2.4226])
>>> dist_probs.probs
tensor([0.5577, 0.0819, 0.1682, 0.1034, 0.0887])
Best.
K. Frank