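I think it’s because logit.argmax() equals softmax(logit).argmax() during inference: softmax is strictly increasing in each logit, so it preserves their ordering, and the largest logit always gets the largest probability.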
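To check the claim about the two argmax calls, here is a minimal sketch in plain Python. The `softmax` and `argmax` helpers and the logit values are made up for illustration; they are not taken from any particular framework or from the model being discussed.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; this does not change the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    # Index of the largest value (the first one if there are ties).
    return max(range(len(values)), key=values.__getitem__)

logits = [1.2, -0.3, 3.4, 0.7]  # hypothetical logits for one sample
probs = softmax(logits)

# Both calls pick the same class index.
print(argmax(logits), argmax(probs))  # prints "2 2"
```

So if you only need the predicted class at inference time, you can probably skip the softmax; you would still need it if you want the actual probabilities.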