Argmax with PyTorch

(Hard) argmax is not differentiable in general (this has nothing to do with PyTorch), so gradient-based methods cannot be applied to it directly: its output is a discrete index that is piecewise constant in the input, so the gradient is zero almost everywhere and undefined where the maximum switches. See e.g. https://www.reddit.com/r/MachineLearning/comments/4e2get/argmax_differentiable/ for a discussion of how to train models involving argmax functions. One alternative suggested there is to use softmax instead.
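
To make the softmax suggestion concrete, here is a minimal sketch of one common way to use it: a "soft argmax" that returns the softmax-weighted average of the index positions. The function name `soft_argmax` and the `temperature` parameter are my own choices for illustration, not a PyTorch API, and this is only one possible relaxation, not necessarily the one discussed in the linked thread.

```python
import torch


def soft_argmax(scores: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Differentiable surrogate for argmax along the last dimension.

    Hypothetical helper (not part of PyTorch): returns the expected index
    under softmax(scores / temperature), which is a float, not an integer.
    """
    # Softmax turns the scores into non-negative weights that sum to 1.
    weights = torch.softmax(scores / temperature, dim=-1)
    # Index positions 0, 1, ..., n-1 in the same dtype/device as the scores.
    positions = torch.arange(
        scores.shape[-1], dtype=scores.dtype, device=scores.device
    )
    # Weighted average of the positions; approaches the hard argmax
    # as the temperature goes to 0 (when the maximum is unique).
    return (weights * positions).sum(dim=-1)


scores = torch.tensor([1.0, 5.0, 2.0], requires_grad=True)

# Hard argmax: an integer index with no gradient attached.
hard_index = torch.argmax(scores)

# Soft argmax: a float close to the hard index, with a usable gradient.
soft_index = soft_argmax(scores, temperature=0.1)
soft_index.backward()  # populates scores.grad
```

A lower temperature pushes the result closer to the hard index but also makes the gradients smaller, so it usually has to be tuned. Also note that the result is an expected position rather than a valid index, so it only makes sense where the positions themselves carry meaning (e.g. coordinates).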
