Custom grad_fn when operation in forward pass detaches from computation graph

Hi, I’m trying to make a custom backward function for argsort. My actual code is a bit more complex, but this is basically what I want to run:

import torch

class TestSorter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.argsort(x)  # indices that would sort x
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return x * y  # placeholder "gradient" just for testing; ignores grad_output


testsort = TestSorter.apply


if __name__ == '__main__':
    x = torch.tensor([5., 4, 3, 2, 1], requires_grad=True)
    y = testsort(x)
    y.sum().backward()
    print(y, x.grad)

This throws RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. If I change y = torch.argsort(x) to y = 0 * x + torch.argsort(x) in the forward pass, it works, but this can hardly be the solution. How do I properly do this?
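
For reference, this is the workaround forward that runs (everything else stays the same); my guess is it only works because adding 0 * x turns the result into a float tensor:

    @staticmethod
    def forward(ctx, x):
        # adding 0 * x promotes the integer argsort result to a float tensor
        y = 0 * x + torch.argsort(x)
        ctx.save_for_backward(x, y)
        return y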

I know the fix is probably really simple, but I just started using PyTorch and was basically thrown in at the deep end. Any help would be appreciated! :)

Hmm, wouldn’t the derivative of argsort be zero everywhere (except at points where two elements are tied, where it wouldn’t be differentiable)?

What is your use case for something like this? Why not just perform argsort on a tensor that does not require grad?

That is true. I’m trying to implement the method from a paper that describes how to reasonably approximate a backward pass for discrete operations, so even for a function like argsort I can define an alternative derivative through which it makes sense to backpropagate. I know the example doesn’t make a lot of sense as written, but I figured there has to be a way to get this working; I don’t really care whether this is the proper derivative of argsort.

Solution: torch.argsort returns an integer (int64) tensor, and integer tensors can never require grad, so your custom backward never gets attached to the output. Cast the result to a floating dtype in the forward pass, i.e. change y = torch.argsort(x) to y = torch.argsort(x).float(). That is also why the 0 * x + torch.argsort(x) workaround happens to work: the addition promotes the result to float.
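
For completeness, a minimal runnable version with that change; the only differences from the original code are the .float() cast in forward and the comments:

import torch

class TestSorter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # argsort returns an integer (int64) tensor, and integer tensors can
        # never require grad, so the output would get no grad_fn. Casting to
        # float lets autograd attach this Function's backward to the output.
        y = torch.argsort(x).float()
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # placeholder "gradient"; a real implementation would use grad_output
        return x * y


testsort = TestSorter.apply


if __name__ == '__main__':
    x = torch.tensor([5., 4, 3, 2, 1], requires_grad=True)
    y = testsort(x)
    y.sum().backward()
    print(y, x.grad)  # y holds the sort indices as floats; x.grad == x * y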