kibrq
(Kirill Brilliantov)
July 8, 2022, 8:29pm
Suppose I have a variable x of shape (L, N) and the following sampling operation:
import torch
import torch.nn.functional as F
softmaxed = torch.softmax(x, dim=1)
sampled = torch.multinomial(softmaxed, k)
one_hot_encoded = F.one_hot(sampled, N).to(torch.float64)
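With concrete sizes the shapes work out as follows. This is a small sketch; the values of L, N and k are hypothetical, and k <= N is required because torch.multinomial samples without replacement by default. It also shows where the gradient is lost: multinomial returns integer indices, so nothing after it is connected to x.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration (k <= N because
# torch.multinomial samples without replacement by default).
L, N, k = 4, 6, 3
x = torch.randn(L, N, requires_grad=True)

softmaxed = torch.softmax(x, dim=1)          # (L, N); each row sums to 1
sampled = torch.multinomial(softmaxed, k)    # (L, k) int64 indices
one_hot_encoded = F.one_hot(sampled, N).to(torch.float64)  # (L, k, N)

print(softmaxed.requires_grad)        # True: softmax is differentiable
print(one_hot_encoded.requires_grad)  # False: integer indices cut the graph
```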
I want to make this sequence of operations differentiable. One way to do so is to assume that the derivative of the whole sequence is approximately the derivative of the softmax.
So I want to subclass torch.autograd.Function to create a function sample that reuses the softmax derivative computation. However, I couldn't find a softmax derivative function in PyTorch. Any advice?
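One option is to skip looking the derivative up, since the softmax vector-Jacobian product has a short closed form. For s = softmax(x) along dim 1 and an upstream gradient g = dL/ds, each row satisfies dL/dx_i = s_i * (g_i - sum_j g_j * s_j). The sketch below computes this by hand and compares it with autograd; the helper name softmax_backward is hypothetical, not a PyTorch API.

```python
import torch

def softmax_backward(grad_output: torch.Tensor, softmaxed: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: vector-Jacobian product of softmax.

    Given s = softmax(x, dim) and g = dL/ds, returns dL/dx = s * (g - sum(g * s, dim)).
    """
    dot = (grad_output * softmaxed).sum(dim=dim, keepdim=True)
    return softmaxed * (grad_output - dot)

x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
s = torch.softmax(x, dim=1)
g = torch.randn(3, 4, dtype=torch.float64)

# Reference gradient from autograd's own softmax backward
(expected,) = torch.autograd.grad(s, x, grad_outputs=g)
manual = softmax_backward(g, s.detach())
print(torch.allclose(manual, expected))  # True
```

Inside a custom Function, one could save softmaxed in forward and return softmax_backward(grad, softmaxed) from backward, which avoids re-enabling grad mode there.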
kibrq
(Kirill Brilliantov)
July 8, 2022, 9:18pm
I came up with the following solution:
Here input.shape = [L, N], where L is the sequence length and N is the number of symbols.
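import torch
import torch.nn.functional as F
from torch.autograd import Function

class SoftmaxSample(Function):
    @staticmethod
    def forward(ctx, input, num_samples):
        N = input.shape[1]
        softmaxed = torch.softmax(input, dim=1)
        # multinomial returns (L, num_samples); transpose to (num_samples, L)
        sampled = torch.multinomial(softmaxed, num_samples, replacement=True).transpose(1, 0)
        # one-hot samples of shape (num_samples, L, N)
        output = F.one_hot(sampled, N).to(torch.float64)
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Grad mode is disabled inside backward (unless create_graph=True), so
        # re-enable it and recompute softmax on a detached copy of the input.
        with torch.enable_grad():
            x = input.detach().requires_grad_()
            softmaxed = torch.softmax(x, dim=1)
            # Average the upstream gradient over the samples -> shape (L, N),
            # cast to the softmax dtype since the output was float64.
            grad_softmaxed = grad_output.mean(dim=0).to(softmaxed.dtype)
            grad_input, = torch.autograd.grad(softmaxed, x, grad_outputs=grad_softmaxed)
        # No gradient for the integer num_samples argument
        return grad_input, None

softmax_sampler = SoftmaxSample.apply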
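For comparison, here is a sketch of a variant that saves the softmax probabilities in forward and applies the closed-form softmax gradient s * (g - sum(g * s)) in backward, so no nested torch.autograd.grad call is needed. It keeps the (num_samples, L, N) output layout and the averaging of the upstream gradient over draws. The class name SoftmaxSampleClosedForm is hypothetical, and returning the one-hot tensor in input.dtype (rather than float64) is an assumption made so the gradient dtypes line up.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function

class SoftmaxSampleClosedForm(Function):
    """Hypothetical variant: forward draws one-hot samples, backward uses the
    closed-form softmax gradient instead of a nested autograd.grad call."""

    @staticmethod
    def forward(ctx, input, num_samples):
        N = input.shape[1]
        softmaxed = torch.softmax(input, dim=1)
        # (L, num_samples) indices -> (num_samples, L)
        sampled = torch.multinomial(softmaxed, num_samples, replacement=True).t()
        ctx.save_for_backward(softmaxed)
        # One-hot samples, shape (num_samples, L, N), in the input's dtype
        return F.one_hot(sampled, N).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        softmaxed, = ctx.saved_tensors
        # Average the upstream gradient over the draws: shape (L, N)
        g = grad_output.mean(dim=0)
        # Row-wise dL/dx = s * (g - sum_j g_j * s_j)
        grad_input = softmaxed * (g - (g * softmaxed).sum(dim=1, keepdim=True))
        return grad_input, None

x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
out = SoftmaxSampleClosedForm.apply(x, 3)   # shape (3, 4, 6)
upstream = torch.randn_like(out)
out.backward(upstream)

# The gradient reaching x should match the plain softmax gradient of the
# upstream gradient averaged over the draws.
x_ref = x.detach().clone().requires_grad_()
torch.softmax(x_ref, dim=1).backward(upstream.mean(dim=0))
print(torch.allclose(x.grad, x_ref.grad))  # True
```

For a single draw per row, the straight-through trick one_hot + softmaxed - softmaxed.detach() gives the same softmax gradient without a custom Function, while its value is still the one-hot sample.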