How to go from probabilities to sampling and backpropagate?

Hello, I’m implementing a network that, as an intermediate step, computes the probability of each element of a sequence being picked (say an 8x1 tensor for a sequence of 8). I need to draw a multinomial sample from it (for example, picking 2 of those elements) and set the tensor to all 0’s except for the sampled elements, which are set to 1. I implemented it as follows:

        multinomial_out = out.transpose(0, 1).multinomial(2, False)

        for i in range(out.size()[0]):
            out[i, 0] = 1 if i in multinomial_out.view(2) else 0

        out = out.view(size_seq)

        return out

Here out is the probability tensor mentioned above.

However, done this way, the gradients of the parameters in the earlier components of the network come out as 0, so nothing is backpropagated. How can I perform this operation while still allowing backpropagation?

Thank you in advance.

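You cannot mathematically differentiate that: sampling is a discrete operation, so the 0/1 output has no useful gradient with respect to the probabilities.
The typical “pretend” solution (often called a straight-through estimator) is to have the backward pass act as if you had passed the probability vector through unchanged, e.g.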

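import torch

pdist = torch.randn(5, 4).softmax(1).requires_grad_()  # each row is a distribution
idx = pdist.multinomial(2)  # sample 2 indices per row, without replacement
out = torch.zeros_like(pdist).scatter_(1, idx, 1.0)  # hard 0/1 mask
# Forward value equals out; backward treats the op as the identity on pdist.
out_with_soft_backward = pdist + (out - pdist).detach()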
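
As a rough sketch of how the trick above might be applied to the 8-element case from the question: the helper name sample_k_hot_straight_through, the logits tensor standing in for the upstream network, and the weighted-sum loss are all made up for illustration and are not from the original post. It only aims to show that the forward value is a 0/1 mask with two ones while a gradient still reaches the upstream tensor.

```python
import torch


def sample_k_hot_straight_through(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: probs is a 1-D probability vector.

    Returns a tensor whose forward value is a k-hot mask, but whose
    backward pass behaves like the identity on probs.
    """
    # Sampling itself is not differentiable, so draw from detached probabilities.
    idx = probs.detach().multinomial(k, replacement=False)
    hard = torch.zeros_like(probs).scatter_(0, idx, 1.0)
    # Value: hard (up to rounding). Gradient: flows to probs only.
    return probs + (hard - probs).detach()


torch.manual_seed(0)
logits = torch.randn(8, requires_grad=True)  # stand-in for the upstream network output
probs = logits.softmax(0)
mask = sample_k_hot_straight_through(probs, k=2)

# Arbitrary loss, chosen only so that the gradient is not trivially zero.
loss = (mask * torch.arange(8.0)).sum()
loss.backward()
```

After backward(), logits.grad is populated, whereas writing the 0/1 values into the tensor element by element, as in the question, cuts the connection to the probabilities. Keep in mind that the resulting gradient is a biased heuristic rather than the true derivative of the sampling step.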

Best regards