Hi, I’m trying to make a custom backward function for argsort. My actual code is a bit more complex, but this is basically what I want to run:
import torch

class TestSorter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.argsort(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return x * y

testsort = TestSorter.apply

if __name__ == '__main__':
    x = torch.tensor([5., 4, 3, 2, 1], requires_grad=True)
    y = testsort(x)
    y.sum().backward()
    print(y, x.grad)
This throws RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. If I change y = torch.argsort(x) to y = 0 * x + torch.argsort(x) in the forward pass, it works (full workaround below), but that can hardly be the right solution. How do I do this properly?
I know the fix is probably really simple, but I just started using PyTorch and was basically thrown in at the deep end. Any help would be appreciated!