Hi, I’m trying to make a custom backward function for argsort. My actual code is a bit more complex, but this is basically what I want to run:
```python
import torch


class TestSorter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # indices that sort x in ascending order
        y = torch.argsort(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # custom gradient with respect to x
        return x * y


testsort = TestSorter.apply

if __name__ == '__main__':
    x = torch.tensor([5., 4, 3, 2, 1], requires_grad=True)
    y = testsort(x)
    y.sum().backward()
    print(y, x.grad)
```
Running this raises `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`.

If I change `y = torch.argsort(x)` to `y = 0 * x + torch.argsort(x)` in the forward pass, it works, but that can hardly be the right solution. How do I do this properly?
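For context on why the workaround works: `torch.argsort` returns `int64` indices, and only floating-point (or complex) tensors can require gradients, so autograd attaches no `grad_fn` to the Function's integer output. Adding `0 * x` promotes the result to floating point, which is the only thing that changes. A minimal sketch of a more direct variant, assuming a floating-point copy of the indices is an acceptable output, casts inside `forward` instead. `TestSorterFloat` is a hypothetical name; the backward keeps the post's placeholder `x * y`, additionally scaled by `grad_output` as the chain rule expects.

```python
import torch


class TestSorterFloat(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # argsort yields int64 indices; casting them to x's floating dtype
        # lets autograd give the output a grad_fn
        y = torch.argsort(x).to(x.dtype)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # placeholder gradient from the post, scaled by the incoming gradient
        return grad_output * x * y


if __name__ == '__main__':
    x = torch.tensor([5., 4, 3, 2, 1], requires_grad=True)
    print(torch.argsort(x).dtype)  # torch.int64
    y = TestSorterFloat.apply(x)
    y.sum().backward()
    # y holds the indices [4, 3, 2, 1] and 0 as floats, so x.grad = x * y
    print(y, x.grad)
```

With `y.sum()` the incoming gradient is all ones, so the result matches the original `x * y`. Note that the indices themselves are still piecewise constant in `x`, so any gradient here is a custom definition rather than the true derivative of argsort.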
I know the fix is probably really simple, but I just started using PyTorch and was basically thrown in at the deep end. Any help would be appreciated!