Another way to say it is that argmax() is not usefully differentiable.

Consider torch.argmax(torch.FloatTensor([x, 1.0])). argmax() will return 1 for x < 1.0 and 0 for x > 1.0. In both cases its derivative (gradient) with respect to x is zero, and it isn't mathematically differentiable right at x = 1.0.

Having zero gradient almost everywhere isn't useful for gradient
descent optimization, so PyTorch doesn't bother to implement
autograd (grad_fn) for argmax().
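One way to see this without autograd at all is to treat that argmax as a plain Python function of x and take a numerical derivative: the output is piecewise constant, so the finite difference is zero everywhere except at the jump at x = 1.0. This is just a sketch; argmax2 and finite_diff below are hypothetical helpers that mimic torch.argmax on a two-element tensor.

```python
def argmax2(x):
    # Index of the larger of [x, 1.0], like torch.argmax(torch.FloatTensor([x, 1.0])):
    # 0 when x > 1.0, otherwise 1.
    return 0 if x > 1.0 else 1

def finite_diff(f, x, h=1e-6):
    # Central finite-difference approximation of df/dx.
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (0.5, 0.9, 1.1, 2.0):
    # The output changes only across x = 1.0; away from that point
    # the numerical derivative is exactly 0.
    print(x, argmax2(x), finite_diff(argmax2, x))
```

Since the function is flat on both sides of the jump, gradient descent would receive no signal about which direction to move x; that is exactly why a grad_fn would be useless here.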