I have a question about the argmax function. When I run:
torch.argmax(output, 1)
I get a result with:
grad_fn = NotImplemented
and when I try:
torch.argmax(output, 1).float().requires_grad_(True)
it shows me:
grad_fn = CopyBackwards
Can you explain to me what exactly that means?
Thank you in advance.
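For reference, something like this is what I'm doing (the `output` tensor here is just a made-up (batch, classes) example standing in for my model's output, and the exact grad_fn string printed may differ between PyTorch versions):

```python
import torch

# Made-up stand-in for the real model output: a (batch, classes) float tensor.
output = torch.randn(4, 10, requires_grad=True)

idx = torch.argmax(output, 1)
print(idx.dtype, idx.grad_fn)       # argmax returns integer indices; inspect its grad_fn

idx_f = torch.argmax(output, 1).float().requires_grad_(True)
print(idx_f.dtype, idx_f.grad_fn)   # same indices after .float().requires_grad_(True)
```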
argmax() is not mathematically differentiable without relaxations.
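Just as an illustration, here is a minimal sketch of one such relaxation (a "soft" argmax that replaces the hard index with a softmax-weighted average of the index positions); the function name, temperature value, and tensor shapes below are only examples, not anything built into pytorch:

```python
import torch

def soft_argmax(logits, temperature=0.1):
    # Differentiable stand-in for argmax over the last dimension:
    # a softmax-weighted average of the index positions.
    weights = torch.softmax(logits / temperature, dim=-1)
    positions = torch.arange(logits.size(-1), dtype=logits.dtype, device=logits.device)
    return (weights * positions).sum(dim=-1)

logits = torch.randn(4, 10, requires_grad=True)
soft_idx = soft_argmax(logits)
soft_idx.sum().backward()
print(logits.grad is not None)   # True: gradients flow through the relaxation
```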
Thanks for your reply. I didn't understand your explanation; could you explain more, please?
Another way to say it is that argmax() is not usefully differentiable. Consider torch.argmax(torch.FloatTensor([x, 1.0])). argmax() will be 1 for x < 1.0 and 0 for x > 1.0. In both cases its derivative (gradient) with respect to x will be zero, and it won't be mathematically differentiable right at x = 1.0.

Having zero gradient almost everywhere isn't useful for gradient descent optimization, so pytorch doesn't bother to implement a backward for argmax().
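As a small illustration of the example above (the particular x values are just made up):

```python
import torch

# argmax of [x, 1.0] jumps from index 1 to index 0 as x crosses 1.0,
# and is constant (hence zero gradient) everywhere else.
for x in (0.5, 0.999, 1.001, 2.0):
    t = torch.tensor([x, 1.0])
    print(x, torch.argmax(t).item())   # 1 while x < 1.0, then 0 once x > 1.0

# The result of argmax() is an integer index tensor with no gradient history,
# so nothing can be backpropagated through it.
x = torch.tensor([0.5, 1.0], requires_grad=True)
idx = torch.argmax(x)
print(idx.requires_grad)               # False
```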
Thank you for your explanation, K. Frank.