Using torch.argsort in CNN

The output of my CNN is a tensor of log probabilities. I want to take the indices of the 10 best outputs and then use those indices to compute the loss, so I have to use torch.argsort. But using torch.argsort breaks the gradients.
Can anyone please help me with how to proceed?
Thank you very much for the help.

That’s expected as argsort is not differentiable and returns indices.
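
To make the point about indices concrete, here is a minimal sketch. The tensor shapes and the names `logits`, `log_probs`, `idx`, and `top_vals` are assumptions for illustration, not taken from the original model. It shows that the indices themselves carry no gradient, while values selected with those indices from the differentiable output still do:

```python
import torch

# Stand-in for the CNN output (assumed shape: batch of 4, 20 classes).
logits = torch.randn(4, 20, requires_grad=True)
log_probs = torch.log_softmax(logits, dim=1)

# argsort returns integer indices, which carry no gradient information.
idx = torch.argsort(log_probs, dim=1, descending=True)[:, :10]
print(idx.dtype, idx.requires_grad)  # torch.int64 False

# Selecting values from the differentiable tensor with those indices
# keeps the graph intact for the selected entries.
top_vals = log_probs.gather(1, idx)
loss = -top_vals.mean()
loss.backward()
print(logits.grad.shape)  # torch.Size([4, 20])
```

Whether a loss built from the gathered values fits the use case depends on what the indices are needed for; if the loss truly depends on the index values themselves, this does not help.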
You could create a custom autograd.Function and implement the backward pass manually in case you want to “approximate” the gradient somehow.
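
A rough sketch of what such a custom function could look like is below. The class name `ArgsortSTE` and the pass-through backward rule are assumptions chosen only to show the mechanics; the rule is not a mathematically justified gradient, and a different surrogate may suit your loss better:

```python
import torch

class ArgsortSTE(torch.autograd.Function):
    """Hypothetical surrogate for argsort (an illustration, not a PyTorch API)."""

    @staticmethod
    def forward(ctx, x):
        # Return the indices as floats so the output can take part in autograd.
        return torch.argsort(x, dim=-1, descending=True).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Assumed surrogate: pass the incoming gradient through unchanged.
        # This is NOT the true derivative, which is zero almost everywhere.
        return grad_output

x = torch.randn(3, 5, requires_grad=True)
ranks = ArgsortSTE.apply(x)
ranks.sum().backward()
print(x.grad)  # all ones, because of the pass-through rule chosen above
```

Only this one operation needs a hand-written backward; autograd still differentiates every other layer that feeds into `x`.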

Will I need to write the entire backward pass myself?
That is, will I have to implement the backward pass for all the weights in the CNN?

You would only need to implement the backward for argsort, as it’s undefined.
Note that the backward pass is undefined because argsort returns indices: a small eps change in the input does not produce a continuous change in the output, so the operation is not differentiable. The more interesting question is therefore how you would implement the gradient.
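
A small numerical check illustrates why there is no useful gradient to compute. The input values and the step sizes here are arbitrary assumptions picked for the demonstration:

```python
import torch

x = torch.tensor([0.3, -1.2, 2.0, 0.7])
eps = 1e-4

base = torch.argsort(x)
diffs = []
for i in range(x.numel()):
    bumped = x.clone()
    bumped[i] += eps
    # Finite difference of the output with respect to x[i].
    diffs.append((torch.argsort(bumped) - base).float() / eps)
print(torch.stack(diffs))  # all zeros: a tiny change leaves the order unchanged

# A larger change makes the output jump instead of moving smoothly.
jumped = x.clone()
jumped[0] += 0.5
print(base.tolist(), torch.argsort(jumped).tolist())  # [1, 0, 3, 2] [1, 3, 0, 2]
```

The output is constant under small perturbations and jumps abruptly once two entries swap order, so any gradient used in a custom backward has to be a deliberately chosen approximation.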