For the Normal distribution, one can use rsample(), which allows gradients to be computed through a sample. A Categorical distribution has no such method.
How can we do the same?
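from torch.distributions import Normal
mu, sigma = neural_net(input)  # neural_net: the poster's model, outputs mean and std
sample = Normal(mu, sigma).rsample()  # reparameterized sample: gradients flow back to mu and sigma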
I would like to do something like:
probs = neural_net(input)
sample = Categorical(probs).rsample()
The problem is that samples from a categorical distribution are discrete: the sampling step is not a differentiable function of probs, so there is no useful gradient to compute. This is in contrast to the Gaussian, where you can write $X = \mu + \sigma Z$ with $Z \sim \mathcal{N}(0, 1)$ to get an $\mathcal{N}(\mu, \sigma^2)$-distributed variable (known in some circles as the reparameterization trick).
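A minimal sketch of that contrast, assuming scalar stand-ins for the network outputs (mu, sigma and p below are illustrative, not from the thread): the hand-written transform and Normal.rsample() both let gradients reach mu and sigma, while Categorical.sample() returns an integer index with no autograd history.

```python
import torch
from torch.distributions import Categorical, Normal

torch.manual_seed(0)

# Stand-ins for a network's outputs; both are leaves so we can inspect .grad.
mu = torch.tensor([0.5], requires_grad=True)
sigma = torch.tensor([2.0], requires_grad=True)

# Reparameterization by hand: parameter-free noise, then a differentiable transform.
eps = torch.randn(1)          # Z ~ N(0, 1)
x_manual = mu + sigma * eps   # X ~ N(mu, sigma^2)
x_manual.sum().backward()
grad_mu_manual = mu.grad.clone()        # dX/dmu = 1
grad_sigma_manual = sigma.grad.clone()  # dX/dsigma = Z

# rsample() applies the same kind of transform internally, so gradients reach mu and sigma.
mu.grad = None
sigma.grad = None
x_rsample = Normal(mu, sigma).rsample()
x_rsample.sum().backward()

# A Categorical sample, by contrast, is an integer index with no autograd history.
p = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)
k = Categorical(probs=p).sample()
```

The has_rsample class attribute is a quick way to check which distributions support this: it is True for Normal and False for Categorical.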
A common way around this is not to sample at all, but to compute the loss for every category and weight each one by its probability. The resulting expected loss is a differentiable function of the probabilities, so gradients with respect to them can be computed.
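A minimal sketch of the expected-loss approach, assuming a hypothetical loss_for_category(k) that stands in for whatever downstream loss follows choosing category k (here just a fixed cost per category). With probs = softmax(logits), the gradient with respect to logit j works out to p_j * (loss_j - E[loss]).

```python
import torch

# Hypothetical per-category losses; in practice loss_for_category(k) would run whatever
# downstream computation follows choosing category k.
category_costs = torch.tensor([1.0, 2.0, 3.0])

def loss_for_category(k: int) -> torch.Tensor:
    return category_costs[k]

logits = torch.tensor([0.1, 0.4, -0.2], requires_grad=True)  # stand-in for network output
probs = torch.softmax(logits, dim=-1)

# Evaluate every category instead of sampling one, and weight each loss by its probability:
# E[loss] = sum_k p_k * loss(k), a smooth function of probs (and therefore of logits).
per_category = torch.stack([loss_for_category(k) for k in range(probs.shape[-1])])
expected_loss = (probs * per_category).sum()
expected_loss.backward()
```

The cost is one loss evaluation per category, which is cheap for a handful of categories but grows linearly with K.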
That was clear, thanks! The solution I used is the Gumbel-Softmax. To make sure the gradients still reach the category probabilities, I did the following (the dot product is computed with torch.mm):
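For reference, a minimal sketch of the Gumbel-Softmax relaxation itself (gumbel_softmax_sample is a hypothetical helper, not part of the thread): add Gumbel(0, 1) noise to the log-probabilities and take a softmax with temperature tau. Taking the argmax of the perturbed log-probabilities (the Gumbel-max trick) gives an exact Categorical sample; the softmax replaces that argmax with a differentiable approximation that sharpens as tau decreases.

```python
import torch

def gumbel_softmax_sample(probs: torch.Tensor, tau: float) -> torch.Tensor:
    """Soft (relaxed) one-hot sample from Categorical(probs) using the Gumbel-Softmax trick."""
    # Gumbel(0, 1) noise by inverse transform: g = -log(-log(U)), U ~ Uniform(0, 1).
    u = torch.rand_like(probs).clamp(min=1e-10)  # guard against log(0)
    g = -torch.log(-torch.log(u))
    # Perturb the log-probabilities, then take a temperature-scaled softmax over categories.
    return torch.softmax((torch.log(probs) + g) / tau, dim=-1)

torch.manual_seed(0)
logits = torch.tensor([[1.0, 0.0, -1.0]], requires_grad=True)  # stand-in for network output
probs = torch.softmax(logits, dim=-1)
y = gumbel_softmax_sample(probs, tau=0.5)  # shape (1, 3), rows sum to 1
y[0, 0].backward()  # any differentiable use of y sends gradients back to logits
```

PyTorch also ships this as torch.nn.functional.gumbel_softmax, which takes logits rather than probabilities.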
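import torch
z, probs = g_s_net(input, temp)  # z: one-hot Gumbel-Softmax samples (batch, K); probs: category probabilities (batch, K)
values = torch.tensor(some_categorical_list, dtype=torch.float32).unsqueeze(1)  # (K, 1): value of each category
sample = torch.mm(z, values)  # (batch, 1): value of the sampled category
loss = (probs * z).sum(dim=1)  # (batch,): probability of the sampled category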
Here z is the one-hot categorical tensor, so multiplying it element-wise with probs zeroes out every entry except the probability of the sampled class.
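A self-contained sketch of the same idea using PyTorch's built-in torch.nn.functional.gumbel_softmax, under these assumptions: random logits stand in for the poster's g_s_net, category_values is a hypothetical stand-in for some_categorical_list, and hard=True is used so that z is exactly one-hot in the forward pass. The matrix product with a 1-D vector (the @ operator) is used here instead of torch.mm, which avoids reshaping the values into a column.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins: a batch of 4 logit vectors over K = 3 categories, and a hypothetical
# numeric value per category (playing the role of some_categorical_list).
logits = torch.randn(4, 3, requires_grad=True)
category_values = torch.tensor([10.0, 20.0, 30.0])

probs = F.softmax(logits, dim=-1)
# hard=True: the forward pass yields one-hot rows, while the backward pass uses the
# gradient of the soft Gumbel-Softmax sample (straight-through estimator).
z = F.gumbel_softmax(logits, tau=0.5, hard=True)

sample = z @ category_values           # (4,) value of the sampled category in each row
chosen_prob = (probs * z).sum(dim=-1)  # (4,) probability of the sampled category in each row

(sample.sum() + chosen_prob.sum()).backward()
```

The straight-through variant keeps downstream computations working with discrete choices while still providing (biased) gradients; with hard=False, z would instead be a soft mixture over categories.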