Input to torch.distributions.categorical.Categorical()

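If you don’t pass logits= explicitly when creating the distribution, the input is interpreted as probs, because probs is the first positional argument (see the docs). Probabilities are normalized to sum to 1, but they must be non-negative. A tensor with negative entries, like the one below, is therefore invalid as probs, even though it is a perfectly valid set of logits (unnormalized log-probabilities).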

Example:

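```python
import torch

x = torch.tensor([-1., -2., -1., 2.])

# Positional argument -> interpreted as probs; negative entries are invalid.
# Recent PyTorch versions raise a ValueError already at construction.
c = torch.distributions.categorical.Categorical(x)
c.sample() # error

# Interpreted as logits (unnormalized log-probabilities) -> valid
c = torch.distributions.categorical.Categorical(logits=x)
c.sample() # works
```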
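A small sketch of what logits= actually does: the values are treated as unnormalized log-probabilities and turned into probabilities with a softmax. Passing torch.softmax(x, dim=-1) positionally (i.e. as probs) should therefore describe the same distribution. validate_args=True is set explicitly on the probs case so that the rejection of negative probabilities doesn't depend on your PyTorch version's default; the variable names are just for illustration.

```python
import torch
from torch.distributions import Categorical  # same class as torch.distributions.categorical.Categorical

x = torch.tensor([-1., -2., -1., 2.])

# logits= : x is treated as unnormalized log-probabilities,
# so the resulting probabilities are softmax(x).
c_logits = Categorical(logits=x)
print(c_logits.probs)

# Passing softmax(x) positionally (as probs) gives the same distribution.
c_probs = Categorical(torch.softmax(x, dim=-1))
print(torch.allclose(c_logits.probs, c_probs.probs))

# Passing the raw x as probs is rejected because of the negative entries.
rejected = False
try:
    Categorical(x, validate_args=True)
except ValueError:
    rejected = True
print("raw x rejected as probs:", rejected)

# Sampling from the logits-based distribution returns an index in [0, 4).
print(c_logits.sample())
```

In practice, use logits= whenever your values come straight out of a network layer (they can be any real numbers), and probs only when you already have non-negative weights.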