Custom Function with the same backward as Softmax

Suppose I have a variable x of shape (L, N) and the following sampling operation (sketched in code right after this list):

  • softmaxed = softmax(x, dim=1)
  • sampled = torch.multinomial(softmaxed, k)
  • one_hot_encoded = torch.nn.functional.one_hot(sampled, N).to(torch.float64)
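For concreteness, here is a minimal sketch of that pipeline (sizes are hypothetical); the point is that multinomial sampling returns integer indices, so the autograd graph is cut there and no gradient can reach x:

import torch
import torch.nn.functional as F

L, N, k = 4, 10, 3
x = torch.randn(L, N, requires_grad=True)

softmaxed = torch.softmax(x, dim=1)            # (L, N), differentiable
sampled = torch.multinomial(softmaxed, k)      # (L, k) integer indices -> no grad_fn
one_hot_encoded = F.one_hot(sampled, N).to(torch.float64)

print(one_hot_encoded.requires_grad)           # False: the graph stops at multinomial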

I want to make this sequence of operations differentiable. To do so, one may say that its derivative is approximately the same as the derivative of the softmax.

So I want to subclass the Function class to make a sample function, and I want to reuse the softmax derivative calculation. I didn't find the softmax derivative calculation exposed anywhere in PyTorch. Any advice?
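For reference, the softmax vector-Jacobian product also has a simple closed form, so the internal kernel isn't strictly needed: for y = softmax(x, dim=1) and an incoming gradient g, grad_x = y * (g - (g * y).sum(dim=1, keepdim=True)). A quick check against autograd (variable names are mine):

import torch

x = torch.randn(5, 7, dtype=torch.float64, requires_grad=True)
g = torch.randn(5, 7, dtype=torch.float64)      # incoming gradient w.r.t. the softmax output

y = torch.softmax(x, dim=1)
auto_grad, = torch.autograd.grad(y, x, grad_outputs=g)

manual_grad = y * (g - (g * y).sum(dim=1, keepdim=True))
print(torch.allclose(auto_grad, manual_grad))   # True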

I came up with the following solution:

input.shape = [L, N], where L is the length of the sequence and N is the number of symbols:

import torch
import torch.nn.functional as F
from torch.autograd import Function


class SoftmaxSample(Function):
    """Forward: sample one-hot symbols from softmax(input).
    Backward: treat the op as if it were just the softmax and route the gradient through it."""

    @staticmethod
    def forward(ctx, input, num_samples):
        L, N = input.shape

        # Build the softmax under enable_grad so that its graph survives into backward,
        # where torch.autograd.grad can reuse it.
        with torch.enable_grad():
            softmaxed = torch.softmax(input, dim=1)
        sampled = torch.multinomial(softmaxed, num_samples, replacement=True).transpose(1, 0)
        output = F.one_hot(sampled, N).to(torch.float64)   # (num_samples, L, N)

        ctx.save_for_backward(input, softmaxed)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, softmaxed = ctx.saved_tensors

        # grad_output is (num_samples, L, N); average over the sample dimension,
        # cast back to the softmax dtype, and backpropagate through the softmax graph.
        grad_outputs = grad_output.mean(dim=0).to(softmaxed.dtype)
        grad_input, = torch.autograd.grad(softmaxed, input, grad_outputs=grad_outputs)
        return grad_input, None


softmax_sampler = SoftmaxSample.apply
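A quick smoke test of the Function above (sizes are my assumptions): the output is one-hot with shape (num_samples, L, N), and gradients reach x through the softmax approximation:

L, N, k = 4, 10, 3
x = torch.randn(L, N, requires_grad=True)

out = softmax_sampler(x, k)        # (k, L, N) one-hot, float64
print(out.shape)                   # torch.Size([3, 4, 10])

out.sum().backward()               # gradient flows via the softmax backward
print(x.grad.shape)                # torch.Size([4, 10])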