Categorical distributions and LogSoftmax

The question concerns the torch.distributions implementation. This is the canonical example from the release page:

probs = policy_network(state)
# NOTE: categorical is equivalent to what used to be called multinomial
m = torch.distributions.Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

Usually, the probabilities are obtained from policy_network as the result of a softmax operation.
However, if we subsequently take the logarithm of the softmax probabilities, aren't we losing numerical precision?
And if so, why isn't the Categorical distribution implemented in a way that allows constructing the object from logits alone?
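
For concreteness, here is a minimal sketch of the precision concern, using made-up logit values and the current tensor API:

import torch
import torch.nn.functional as F

# Hypothetical raw scores; the very unlikely last category is the interesting one.
logits = torch.tensor([100.0, 0.0, -100.0])

# log(softmax(x)) underflows for very unlikely categories ...
naive = torch.log(F.softmax(logits, dim=0))    # approximately [0., -100., -inf]

# ... while log_softmax computes the same quantity stably.
stable = F.log_softmax(logits, dim=0)          # approximately [0., -100., -200.]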


Usually, the probabilities are obtained from policy_network as a result of a softmax operation.

You should put a LogSoftmax in the policy network’s last layer, not a Softmax, precisely for the reason you described.
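
A minimal sketch of what that looks like (the layer sizes here are made up):

import torch.nn as nn

# Policy network that emits log-probabilities directly.
policy_network = nn.Sequential(
    nn.Linear(4, 128),
    nn.ReLU(),
    nn.Linear(128, 2),
    nn.LogSoftmax(dim=-1),  # stable log-probabilities over actions
)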

you should put a LogSoftmax in the policy network’s last layer, not Softmax

If I do so, I will be required to subsequently call exp(log_probs) to match the signature of torch.distributions.Categorical.
Also, the subsequent call to torch.distributions.Categorical(probs).log_prob(0) will take the log again, which reintroduces the same numerical issues.
So either I don’t understand the purpose of the current implementation of torch.distributions, or the example on the GitHub release page is misleading.
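
A sketch of the round-trip being described (the scores are illustrative, not from the thread):

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

scores = torch.tensor([2.0, -1.0, 0.5])      # hypothetical network outputs
log_probs = F.log_softmax(scores, dim=0)     # computed stably ...
m = Categorical(torch.exp(log_probs))        # ... then exponentiated to fit the probs-only signature
# m.log_prob(...) takes the log of probs again internally, undoing the benefit.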

Also, something seems to be wrong with sampling from Categorical. After executing

import torch
import torch.nn.functional as F

probs = F.softmax(torch.autograd.Variable(torch.Tensor([.25, .6])), dim=0)
dist = torch.distributions.Categorical(probs)

this code works fine

torch.multinomial(dist.probs, 2, True)

but this one does not

dist.sample_n(10)

yielding the following stack trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-792c4214dc6f> in <module>()
----> 1 dist.sample_n(10)

~/anaconda3/lib/python3.6/site-packages/torch/distributions.py in sample_n(self, n)
    143             return self.sample().expand(1, 1)
    144         else:
--> 145             return torch.multinomial(self.probs, n, True).t()
    146 
    147     def log_prob(self, value):

RuntimeError: invalid argument 2: out of range at /Users/soumith/minicondabuild3/conda-bld/pytorch_1512381214802/work/torch/lib/TH/generic/THTensor.c:479

This happens because of the last .t(): with a 1-D probs, torch.multinomial returns a 1-D tensor, and transposing it fails in this version.
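
As a workaround on this version, one can call torch.multinomial directly and skip the transpose (a sketch, reusing the dist defined above):

# Draw 10 samples with replacement; returns a 1-D LongTensor of category indices.
samples = torch.multinomial(dist.probs, 10, replacement=True)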


@fritzo does it make sense for us to have a keyword argument logits for Categorical like TensorFlow?

Yes, it makes sense, and we already do that in Pyro: https://github.com/uber/pyro/blob/dev/pyro/distributions/categorical.py#L34
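
A sketch of what that API would look like, mirroring the Pyro signature linked above (assumes a version of Categorical that accepts a logits keyword):

import torch
from torch.distributions import Categorical

log_scores = torch.tensor([2.0, -1.0, 0.5])   # hypothetical unnormalized log-probabilities
m = Categorical(logits=log_scores)            # no probs round-trip needed
action = m.sample()
log_prob = m.log_prob(action)                 # computed from logits, avoiding log(exp(...)) loss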


Hi @ipaulo, did you find an answer to your question? I am interested too.

Hi, I think the log_prob comes from the policy optimization algorithm (i.e., it converts a product of probabilities into a sum of log-probabilities, as here). The output of the policy network should still be a distribution over your action space.
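
A small sketch of the product-to-sum identity being referred to (the numbers are made up):

import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.6, 0.3])        # hypothetical action distribution
dist = Categorical(probs)
actions = torch.tensor([1, 1, 2])            # a made-up sequence of actions

# The log of a product of probabilities equals the sum of the log-probabilities.
sum_of_logs = dist.log_prob(actions).sum()
log_of_product = torch.log(probs[actions].prod())
assert torch.allclose(sum_of_logs, log_of_product)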