The output of my CNN is log probabilities. I want to get the indices of the 10 best outputs and then use those indices to compute the loss. Hence, I have to use `torch.argsort`. But using `torch.argsort` breaks the gradients.

Can anyone please help me with how to proceed?

Thank you very much for the help.

That’s expected, as `argsort` is not differentiable and returns indices. You could create a custom `autograd.Function` and implement the `backward` pass manually in case you want to “approximate” the gradient somehow.
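As a rough sketch of what such a custom `autograd.Function` could look like: the class below returns a hard 0/1 mask over the top-k entries in `forward` and simply passes the incoming gradient through unchanged in `backward` (a straight-through style estimator). The name `HardTopKMask`, the choice of a mask instead of raw indices, and the identity surrogate gradient are all assumptions made for illustration, not something prescribed in this thread. Whether this surrogate is suitable depends on your loss.

```python
import torch

class HardTopKMask(torch.autograd.Function):
    """Hypothetical example: hard top-k selection with a surrogate gradient."""

    @staticmethod
    def forward(ctx, scores, k):
        # Non-differentiable step: pick the indices of the k largest scores.
        idx = scores.topk(k, dim=-1).indices
        # Return a float mask (1.0 at selected positions) so the output can carry gradients.
        return torch.zeros_like(scores).scatter_(-1, idx, 1.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumed surrogate: treat the mask as if it were the identity w.r.t. scores.
        # The second return value is None because k is not a tensor.
        return grad_output, None


torch.manual_seed(0)
logits = torch.randn(4, 20, requires_grad=True)
log_probs = torch.log_softmax(logits, dim=1)

mask = HardTopKMask.apply(log_probs, 10)
loss = -(mask * log_probs).sum(dim=1).mean()
loss.backward()  # gradients reach logits through both log_probs and the mask
```

Only the `backward` of this one operation is written by hand; everything before and after it is still differentiated by autograd.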

Will I need to code the entire backward pass? That is, will I have to code the backward pass for all the weights in the CNN?

You would only need to implement the `backward` for `argsort`, as it’s undefined.
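To illustrate that the rest of the model is still handled by autograd, here is a minimal sketch with a tiny stand-in CNN (the architecture, shapes, and loss are made up for the example). The indices returned by `argsort` carry no gradient, but the log probabilities gathered with those indices do, so the convolution weights still receive gradients without any hand-written backward code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny CNN standing in for the real model.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 20),
    nn.LogSoftmax(dim=1),
)

x = torch.randn(2, 1, 8, 8)
log_probs = model(x)  # shape (2, 20)

# Indices of the 10 largest log probabilities per sample (integer tensor, no grad).
idx = torch.argsort(log_probs, dim=1, descending=True)[:, :10]
print(idx.dtype, idx.requires_grad)  # torch.int64 False

# Gathering the values at those indices is differentiable w.r.t. log_probs.
top10 = log_probs.gather(1, idx)
loss = -top10.mean()  # placeholder loss for the example
loss.backward()

conv_grad = model[0].weight.grad  # populated by autograd
```

Note that in this sketch the gradient only flows into the selected values, not into the selection itself. If the loss has to depend on *which* indices were chosen, that is where a custom `backward` would come in.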

Note that the `backward` pass is undefined because `argsort` returns indices, so the operation is not differentiable: small `eps` changes of the input will not result in any continuous changes in its output. The more interesting question is therefore how you would implement the gradient.
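A small sketch of this behaviour, using made-up values: perturbing the input by a small `eps` leaves the indices unchanged, so a finite-difference "derivative" is zero, while a larger change makes the output jump to a different permutation.

```python
import torch

x = torch.tensor([0.3, 1.2, -0.5, 0.9])
eps = 1e-4
direction = torch.tensor([1.0, -1.0, 1.0, -1.0])

idx = torch.argsort(x, descending=True)
idx_eps = torch.argsort(x + eps * direction, descending=True)

# Small perturbation: the ordering is unchanged, so the output does not move at all.
same_order = torch.equal(idx, idx_eps)
finite_diff = (idx_eps - idx).float() / eps  # all zeros

# Larger perturbation: element 3 overtakes element 1 and the output jumps.
x_big = x.clone()
x_big[3] += 0.5
idx_big = torch.argsort(x_big, descending=True)

print(idx.tolist())      # [1, 3, 0, 2]
print(idx_big.tolist())  # [3, 1, 0, 2]
```

The output is piecewise constant in the input, which is why any gradient you define for it is necessarily an approximation.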