Argmax with PyTorch

How can I use argmax with PyTorch?

Meta:

  • How could I have answered my own question? I searched the PyTorch docs and the PyTorch repo for “argmax” but got no results.
  • Does it make sense to use argmax with a GPU?

torch.max, when given a dimension, returns both the max values and their indices, so you can do:

values, indices = tensor.max(0)
values, indices = torch.max(tensor, 0)
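
A minimal sketch of what the two calls above return, using a small hand-written 2-D tensor (the values are made up for illustration). The second argument picks the dimension that gets reduced. The `.values` / `.indices` named-tuple fields assume a reasonably recent PyTorch release.

```python
import torch

# A small 2-D tensor with known values so the results are easy to check by eye
t = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# Reduce over dim 0 (down each column): one result per column
values0, indices0 = torch.max(t, 0)
print(values0, indices0)  # tensor([4., 5., 6.]) tensor([1, 0, 1])

# Reduce over dim 1 (across each row): one result per row
values1, indices1 = t.max(1)
print(values1, indices1)  # tensor([5., 6.]) tensor([1, 2])

# The result can also be accessed as a named tuple
result = t.max(dim=1)
print(result.values, result.indices)
```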

I know that the values are differentiable (e.g., global max-pooling). Are the indices also differentiable? Thank you very much!

(Hard) argmax is not differentiable in general (this has nothing to do with PyTorch): its output jumps from one integer to another as the inputs change, so its gradient is zero almost everywhere and undefined where the maximum switches. This means you cannot train through it with gradient-based methods. See e.g. https://www.reddit.com/r/MachineLearning/comments/4e2get/argmax_differentiable/ for ways to train models that involve argmax. One alternative suggested there is to use softmax instead.
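
One way to follow the softmax suggestion is a so-called soft-argmax: take the softmax-weighted average of the positions, which approaches the hard argmax as the temperature goes to zero. This is a sketch under my own assumptions (the helper name `soft_argmax` and its `temperature` parameter are hypothetical, not a PyTorch API), not the method from the linked thread.

```python
import torch

def soft_argmax(scores: torch.Tensor, temperature: float = 1.0, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: differentiable stand-in for argmax along `dim`."""
    # Softmax turns scores into weights that sum to 1 along `dim`
    weights = torch.softmax(scores / temperature, dim=dim)
    # Position indices 0..n-1 as floats, reshaped to broadcast along `dim`
    n = scores.size(dim)
    positions = torch.arange(n, dtype=scores.dtype, device=scores.device)
    shape = [1] * scores.dim()
    shape[dim] = n
    positions = positions.view(shape)
    # Expected position under the softmax weights
    return (weights * positions).sum(dim=dim)

scores = torch.tensor([0.1, 3.0, 0.2], requires_grad=True)

hard = scores.argmax()                       # integer result, no gradient
soft = soft_argmax(scores, temperature=0.1)  # float result close to 1.0
soft.backward()                              # gradients reach `scores`
print(hard, soft, scores.grad)
```

A lower temperature gives a result closer to the hard argmax but sharper, less well-behaved gradients; a higher temperature blurs the estimate toward the middle of the index range.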

Thank you very much!

Any ideas for how to use this max function in a differentiable way? A custom loss function I'm writing depends on the indices of the max values, and I'm not sure how to rewrite it so that it uses only differentiable components.

Hi @bgenchel,

I am not sure if I understand your problem. Do you want to have gradients with respect to indices? Well, indices are integers by definition and you cannot take derivatives of a function with respect to a variable that is defined over the integers only…
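
To make the distinction concrete for the earlier max-pooling question, here is a small sketch (arbitrary example values) showing that the index returned by max is an integer tensor outside the autograd graph, while the gradient of the max value flows back only to the element that was selected.

```python
import torch

x = torch.tensor([2.0, 7.0, 1.0, 4.0], requires_grad=True)
value, index = x.max(0)

# The index is an integer tensor and is not part of the autograd graph
print(index.dtype)          # torch.int64
print(index.requires_grad)  # False

# The value is differentiable: its gradient is 1 at the selected position, 0 elsewhere
value.backward()
print(x.grad)  # tensor([0., 1., 0., 0.])
```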

What does the 0 do in the indexing?

Yeah, I found the zero confusing too. It's the dimension along which you want to find the max. I was confused because the tensor I wanted the max of had shape (1, 49), so torch.max(preds, 0) reduced over the size-1 first dimension and just gave me back the whole row (with every index 0), which didn't make sense. I needed torch.max(preds, 1), and that indeed returned (max value, index).
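
A small sketch reproducing the shape situation described above, with random scores standing in for `preds` (the real model output is not shown in the thread):

```python
import torch

preds = torch.randn(1, 49)  # one sample with 49 scores, as in the post

# dim=0 reduces over the size-1 first axis: each column's "max" is just itself
v0, i0 = torch.max(preds, 0)
print(v0.shape, i0.shape)  # torch.Size([49]) torch.Size([49])

# dim=1 reduces over the 49 scores: one max (and one index) per row
v1, i1 = torch.max(preds, 1)
print(v1.shape, i1.shape)  # torch.Size([1]) torch.Size([1])
```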

As of version 0.4.0, you can also do torch.argmax(preds, dim=1).
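
A sketch of torch.argmax on a small made-up tensor, assuming a current PyTorch version. With `dim` it matches the index part of torch.max; without `dim` it returns an index into the flattened tensor, which can be turned back into (row, col) with Python's `divmod` for a 2-D input.

```python
import torch

preds = torch.tensor([[0.2, 0.9, 0.1],
                      [0.7, 0.2, 0.6]])

# With dim: one index per row, same as the indices from torch.max(preds, 1)
per_row = torch.argmax(preds, dim=1)  # tensor([1, 0])

# Without dim: an index into the flattened tensor
flat = torch.argmax(preds)            # tensor(1), the position of 0.9

# Convert the flat index back to (row, col) for a 2-D tensor
row, col = divmod(flat.item(), preds.size(1))  # (0, 1)

# keepdim=True keeps the reduced dimension with size 1
kept = torch.argmax(preds, dim=1, keepdim=True)  # shape (2, 1)
```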

@BlakeWest Dimension 0 is the batch and dimension 1 holds the class probabilities (assuming you apply softmax to your final output). Therefore you want an argmax along dimension 1, i.e. the class with the highest probability.
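
A sketch of the usual classification pattern, with random logits and hypothetical labels (`targets`) standing in for real model output. Because softmax is monotonic, taking the argmax over the logits gives the same classes as taking it over the probabilities, so the softmax is not needed just to get predictions.

```python
import torch

torch.manual_seed(0)
batch_size, num_classes = 4, 10
logits = torch.randn(batch_size, num_classes)            # dim 0: batch, dim 1: classes
targets = torch.randint(0, num_classes, (batch_size,))   # hypothetical labels

probs = torch.softmax(logits, dim=1)
pred_from_probs = probs.argmax(dim=1)    # one class index per sample
pred_from_logits = logits.argmax(dim=1)  # softmax preserves order, so same classes

# Fraction of samples whose predicted class matches the label
accuracy = (pred_from_logits == targets).float().mean()
print(pred_from_logits, accuracy)
```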

Hello. How can I manipulate a tensor using these indices? For example, what should I do if I want to change the values of the tensor elements at those indices?
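
One possible answer, sketched on a made-up tensor: pair the per-row indices with row numbers for advanced indexing, or pass them to `scatter_` (or `gather`, to read) as a column vector. Both writes below are in place, so the example modifies clones of the original.

```python
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])
_, idx = x.max(dim=1)          # tensor([1, 2]): column of the max in each row
rows = torch.arange(x.size(0))

# Read the selected elements, two equivalent ways
selected = x[rows, idx]                          # tensor([5., 6.])
gathered = x.gather(1, idx.unsqueeze(1)).squeeze(1)

# Option 1: advanced-indexing assignment (in place on the clone)
y = x.clone()
y[rows, idx] = 0.0

# Option 2: scatter_ with the indices as a column vector (also in place)
z = x.clone()
z.scatter_(1, idx.unsqueeze(1), 0.0)
print(y)  # tensor([[1., 0., 3.], [4., 2., 0.]])
```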

What about calling argmax on the tensor itself, like:

import torch

x = torch.randn(3)
x.argmax(-1)
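
As far as I can tell, the Tensor method and the free function are the same operation. A quick sketch comparing them on random input:

```python
import torch

x = torch.randn(2, 3)
by_method = x.argmax(dim=-1)           # Tensor.argmax
by_function = torch.argmax(x, dim=-1)  # torch.argmax
print(by_method, by_function)
```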

Meta:
- Where does one find the docs for this? Googling takes you to the torch.argmax function, not the Tensor method. Answer: it seems to be documented here: https://pytorch.org/docs/stable/tensors.html


probably useful:

What do negative indices (like the -1 in x.argmax(-1)) do? @fmassa
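
For what it's worth: in x.argmax(-1) the -1 is a dimension, not an element index. Negative dims count from the end, as in Python indexing, so for an n-dimensional tensor dim=-1 means dim=n-1. A quick sketch checking this on a random 3-D tensor:

```python
import torch

x = torch.randn(2, 3, 4)

# For a 3-D tensor, dim=-1 is dim 2 and dim=-2 is dim 1
a = x.argmax(dim=-1)
b = x.argmax(dim=2)
c = x.argmax(dim=-2)
d = x.argmax(dim=1)
print(a.shape, c.shape)  # torch.Size([2, 3]) torch.Size([2, 4])
```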