Argmax with PyTorch

How can I use argmax with PyTorch?


  • How could I have answered my own question? I searched the PyTorch docs and the PyTorch repo for “argmax” but got no results.
  • Does it make sense to use argmax with a GPU?
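On the GPU question: argmax is an ordinary reduction and runs on CUDA tensors the same way it runs on CPU tensors. It is most worthwhile when the tensor already lives on the GPU, because copying data back to the host just to take an argmax adds a transfer. A minimal device-agnostic sketch, assuming a PyTorch version that has torch.argmax (0.4.0 or later) and falling back to the CPU when no GPU is present:

```python
import torch

# Pick the GPU if one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A batch of 4 score vectors with 10 entries each, created directly on `device`.
scores = torch.randn(4, 10, device=device)

# argmax runs on whichever device holds the tensor, and the result stays there.
best = torch.argmax(scores, dim=1)
print(best.device, best.dtype, best.shape)

# Move the result to the CPU only when you actually need it on the host.
best_cpu = best.cpu()
```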

When called with a dimension argument, torch.max returns both the max values and their indices, so you can do:

values, indices = tensor.max(0)
values, indices = torch.max(tensor, 0)
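To make the two lines above concrete, here is a small sketch on a hand-picked 2×3 tensor (the values are chosen only for illustration). It also uses the named-tuple fields `.values` and `.indices`, which recent PyTorch versions return from torch.max with a dim; older versions may only support tuple unpacking.

```python
import torch

tensor = torch.tensor([[1.0, 5.0, 2.0],
                       [4.0, 3.0, 6.0]])

# Reduce over dim 0 (down each column): one max per column.
values, indices = tensor.max(0)
print(values)   # tensor([4., 5., 6.])
print(indices)  # tensor([1, 0, 1])

# The function form gives the same result.
values_fn, indices_fn = torch.max(tensor, 0)

# Recent versions return a named tuple, so fields can also be read by name.
out = torch.max(tensor, dim=1)
print(out.values)   # tensor([5., 6.])
print(out.indices)  # tensor([1, 2])
```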

I know that the values are differentiable (e.g., global max-pooling). Are the indices also differentiable? Thank you very much!


(Hard) argmax is not differentiable in general (this has nothing to do with PyTorch): it is piecewise constant, so its gradient is zero almost everywhere and undefined wherever the maximum switches position. In other words, you cannot use gradient-based methods with argmax. See e.g. on how to train models involving argmax functions. One potential alternative suggested there is to use softmax instead.
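One common way to act on the softmax suggestion is a "soft argmax": weight each position by its softmax probability and sum, which gives a differentiable expected index that approaches the hard argmax as a temperature goes to zero. This is a minimal sketch of that relaxation, not necessarily the method in the reference mentioned above; `soft_argmax` and its `temperature` parameter are hypothetical names chosen here.

```python
import torch

def soft_argmax(x: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: a differentiable stand-in for argmax over the last dim.

    Returns the softmax-weighted average of the positions 0..n-1.
    Lower temperatures sharpen the softmax and push the result toward the hard argmax.
    """
    weights = torch.softmax(x / temperature, dim=-1)
    positions = torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
    return (weights * positions).sum(dim=-1)

x = torch.tensor([0.1, 2.0, 0.3, 0.5], requires_grad=True)

hard = torch.argmax(x)                   # integer index 1, no gradient
soft = soft_argmax(x, temperature=0.5)   # a float close to 1.0

soft.backward()                          # gradients flow back to x
print(hard.item(), soft.item())
print(x.grad)
```

The trade-off is the temperature: small values track the hard argmax closely but make the gradients tiny, while larger values give smoother gradients and a blurrier "index".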


Thank you very much!


Any ideas on how to use this max function in a differentiable way? A custom loss function I’m writing depends on the indices of the max values, and I’m not sure how to rewrite it so that it uses only differentiable components.

Hi @bgenchel,

I am not sure I understand your problem. Do you want gradients with respect to the indices? Well, indices are integers by definition, and you cannot take derivatives of a function with respect to a variable that is defined only over the integers…
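A quick way to see the difference between the two outputs of max, using a small hand-picked tensor: the indices come back as an integer tensor with no autograd history, while the values carry gradients that are routed to the positions of the maxima (the behavior global max-pooling relies on).

```python
import torch

x = torch.tensor([[0.2, 1.5, -0.3],
                  [2.0, 0.1, 0.7]], requires_grad=True)

values, indices = x.max(dim=1)

# The indices are an integer tensor and do not require gradients.
print(indices.dtype, indices.requires_grad)   # torch.int64 False

# The values are differentiable: the gradient of their sum lands
# only on the max position in each row.
values.sum().backward()
print(x.grad)
# tensor([[0., 1., 0.],
#         [1., 0., 0.]])
```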

What does the 0 do in the indexing?

Yeah, I found the zero confusing too. It’s the dimension along which you want to find the max. I was getting confused because in my case the tensor I wanted the max of had shape (1, 49), so when I did torch.max(preds, 0), I just got back the whole row (with indices that were all 0), which didn’t make any sense. I needed to do torch.max(preds, 1), and that indeed returned (max value, index).
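The shape issue described above can be reproduced with a random (1, 49) tensor standing in for `preds`: reducing over dim 0, which has size 1, takes the "max" of a single element per column.

```python
import torch

preds = torch.randn(1, 49)   # one sample with 49 scores, random for illustration

# dim=0 has size 1, so each max is over a single element:
# the row comes back unchanged and every index is 0.
values0, indices0 = torch.max(preds, 0)
print(values0.shape, indices0.unique())   # torch.Size([49]) tensor([0])

# dim=1 reduces over the 49 scores, which is what was wanted.
values1, indices1 = torch.max(preds, 1)
print(values1.shape, indices1.shape)      # torch.Size([1]) torch.Size([1])
```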


As of version 0.4.0, you can now do torch.argmax(preds, dim=1).
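A small sketch, with made-up scores, showing that torch.argmax returns the same indices as the second output of torch.max along that dimension, plus the `keepdim` option, which keeps the reduced dimension with size 1:

```python
import torch

preds = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])

# torch.argmax returns only the indices...
idx = torch.argmax(preds, dim=1)
print(idx)   # tensor([1, 0])

# ...which match the indices returned by torch.max along the same dimension.
assert torch.equal(idx, torch.max(preds, 1)[1])

# keepdim=True keeps dim 1 with size 1, which is handy for broadcasting or gather.
idx_keep = torch.argmax(preds, dim=1, keepdim=True)
print(idx_keep.shape)   # torch.Size([2, 1])
```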

@BlakeWest Dimension 0 is the batch and dimension 1 holds the class probabilities (assuming you apply softmax to your final output). Therefore you want to take the argmax along dimension 1, i.e., the class with the highest probability.
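A minimal classification sketch of the advice above, assuming a batch of random logits with shape (batch, classes) and made-up targets used only to show an accuracy computation. Because softmax is monotonic within each row, taking the argmax of the raw logits gives the same predicted classes.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8, 5)            # batch of 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))   # made-up labels, for illustration only

probs = torch.softmax(logits, dim=1)  # normalize across the class dimension

# Predicted class per sample: argmax over dimension 1.
pred = probs.argmax(dim=1)            # shape (8,)

# softmax preserves the ordering within each row, so the logits give the same argmax.
assert torch.equal(pred, logits.argmax(dim=1))

accuracy = (pred == targets).float().mean().item()
print(pred, accuracy)
```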


Hello. How can I manipulate a tensor using the indices? For example, what should I do if I want to change the values of the tensor elements corresponding to those indices?
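One way to do this, sketched on a small made-up tensor: pair each row number with its argmax column using advanced indexing, or use `scatter_` along dim 1. Setting the max to 0 is just an example of "changing the value"; any scalar works.

```python
import torch

x = torch.tensor([[0.2, 1.5, -0.3],
                  [2.0, 0.1, 0.7]])

idx = x.argmax(dim=1)   # column of the max in each row

# Option 1: advanced indexing, pairing each row number with its argmax column.
rows = torch.arange(x.shape[0])
y = x.clone()
y[rows, idx] = 0.0
print(y)

# Option 2: scatter_ along dim 1; the index needs as many dims as x, hence unsqueeze.
z = x.clone()
z.scatter_(1, idx.unsqueeze(1), 0.0)
assert torch.equal(y, z)
```

Both options work on a clone here so the original `x` stays untouched; drop the `clone()` to modify the tensor in place.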

What about calling argmax on the tensor itself, like:

import torch

x = torch.randn(3)

Where does one find the docs for this? Googling takes you to the torch.argmax function and not the one for tensors… Answer: it seems this is where it is:

Probably useful:
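For the tensor-method question above, a short sketch with hand-picked values: `Tensor.argmax` is the method form of `torch.argmax`. When no `dim` is given, the input is treated as flattened, so a 2-D tensor yields a single flat index, which `divmod` by the number of columns turns back into (row, col).

```python
import torch

x = torch.tensor([0.3, -1.2, 2.5])

# The method form and the function form give the same result.
print(x.argmax())   # tensor(2)
assert torch.equal(x.argmax(), torch.argmax(x))

# With no dim, the input is flattened, so a 2-D tensor yields one flat index.
m = torch.tensor([[1.0, 9.0],
                  [4.0, 3.0]])
flat = m.argmax()                          # tensor(1)
row, col = divmod(flat.item(), m.shape[1])
print(row, col)                            # 0 1
```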

What do negative indices do? @fmassa
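Assuming "negative indices" here means a negative `dim` argument: as with Python sequence indexing, a negative dim counts from the end, so `dim=-1` is the last dimension. A small sketch on a random 3-D tensor:

```python
import torch

t = torch.randn(2, 3, 4)

# For a 3-D tensor, dim=-1 means dim 2 and dim=-3 means dim 0.
assert torch.equal(t.argmax(dim=-1), t.argmax(dim=2))
assert torch.equal(t.argmax(dim=-3), t.argmax(dim=0))

print(t.argmax(dim=-1).shape)   # torch.Size([2, 3])
```

Using `dim=-1` is convenient for code that should reduce over the last (for example, class) dimension regardless of how many leading batch dimensions there are.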