How can I use argmax with PyTorch?

Meta:

- How could I have answered my own question? I searched the PyTorch docs and the PyTorch repo for “argmax” but got no results.
- Does it make sense to use argmax with a GPU?
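On the GPU question: argmax is an ordinary tensor op, so it runs on whatever device the tensor already lives on, and the result stays there. When the data is already on the GPU, that avoids a copy back to the host. The sketch below assumes only that PyTorch is installed. It falls back to the CPU when no CUDA device is present, and the shapes are arbitrary.

```python
import torch

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A batch of 4 rows with 10 scores each, allocated directly on `device`
scores = torch.randn(4, 10, device=device)

# argmax runs on the tensor's own device, and the result stays there too
best = scores.argmax(dim=1)
print(best.device, best.shape)
```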

`torch.max` returns both the max values and the indices, so you can do:

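```
import torch
tensor = torch.randn(3, 4)
# Both forms reduce over dimension 0 and return (values, indices)
values, indices = tensor.max(0)
values, indices = torch.max(tensor, 0)
```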
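Here is a small worked example with a fixed tensor so the output can be checked by hand. The tensor `t` is just an illustration. With `dim=0` the comparison runs down each column, and with `dim=1` it runs across each row.

```python
import torch

t = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 0.0, 6.0]])

# dim=0: compare down each column
col_vals, col_idx = torch.max(t, 0)
print(col_vals, col_idx)  # tensor([4., 5., 6.]) tensor([1, 0, 1])

# dim=1: compare across each row
row_vals, row_idx = t.max(1)
print(row_vals, row_idx)  # tensor([5., 6.]) tensor([1, 2])
```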

I know that the `values` are differentiable (e.g., global max-pooling). Are the `indices` also differentiable? Thank you very much!

(Hard) `argmax` is not differentiable in general (this has nothing to do with PyTorch), i.e. one cannot use gradient-based methods with argmax. See e.g. https://www.reddit.com/r/MachineLearning/comments/4e2get/argmax_differentiable/ on how to train models involving argmax functions. One potential alternative suggested there is to use softmax instead.

Thank you very much!

Any ideas for how to use this max function in a differentiable way? A custom loss function I'm writing depends on the indices of the max values, and I'm not sure how to rewrite it so that it uses only differentiable components.

Hi @bgenchel,

I am not sure I understand your problem. Do you want gradients with respect to the indices? Well, indices are integers by definition, and you cannot take derivatives of a function with respect to a variable that is defined only over the integers…
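To make this concrete, here is what autograd does with the two outputs of `max`. The tensor `x` is an arbitrary illustration. The indices come back as an integer tensor with no gradient history. Gradients do flow through the values, landing as 1 on the winning position and 0 elsewhere.

```python
import torch

x = torch.tensor([0.5, 2.0, -1.0, 1.5], requires_grad=True)
values, indices = x.max(0)

# The indices are an integer tensor with no autograd history
print(indices.dtype, indices.requires_grad)  # torch.int64 False

# The max value is differentiable: its gradient is 1 at the winning
# position and 0 everywhere else
values.backward()
print(x.grad)  # tensor([0., 1., 0., 0.])
```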

What does the 0 do in the indexing?

Yeah, I found the zero confusing too. It's the dimension along which you want to find the max. I got confused because the tensor I wanted the max of had shape (1, 49), so when I did `torch.max(preds, 0)` I just got back the whole array, which didn't make any sense. I needed to do `torch.max(preds, 1)`, and that indeed returned (max value, index).
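The (1, 49) case can be reproduced with random scores standing in for `preds`. With only one row, reducing over dim 0 returns the row itself, and every index is 0. Reducing over dim 1 returns the single max and its position.

```python
import torch

preds = torch.randn(1, 49)  # one sample, 49 scores

# dim=0 reduces over the single row, so every column "wins" at index 0
v0, i0 = torch.max(preds, 0)
print(v0.shape, i0.unique())  # torch.Size([49]) tensor([0])

# dim=1 reduces over the 49 scores, giving one (value, index) per row
v1, i1 = torch.max(preds, 1)
print(v1.shape, i1.shape)  # torch.Size([1]) torch.Size([1])
```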

You can now do `torch.argmax(preds, dim=1)` as of version 0.4.0.
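Here `torch.argmax` is compared with the index output of `torch.max` on a small fixed tensor. The values in `preds` are made up for illustration. `argmax` simply skips returning the values.

```python
import torch

preds = torch.tensor([[0.2, 0.9, 0.1],
                      [0.7, 0.3, 0.5]])

# torch.argmax returns only the indices...
idx = torch.argmax(preds, dim=1)
print(idx)  # tensor([1, 0])

# ...which match the second output of torch.max along the same dim
_, max_idx = torch.max(preds, dim=1)
print(torch.equal(idx, max_idx))  # True
```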

@BlakeWest Dimension 0 is the batch and dimension 1 is the class probabilities (assuming you apply softmax to your final output). Therefore you want to do an argmax along dimension 1, i.e. pick the class with the highest probability.
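A minimal sketch of that batch-classification layout. The `logits` tensor is a hypothetical model output of shape (batch, classes). Softmax normalizes over dim 1, and argmax over dim 1 then picks one class per sample. Softmax preserves the ordering within each row, so taking argmax of the raw logits gives the same predictions.

```python
import torch

# Hypothetical model output: a batch of 2 samples, 3 classes each
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2]])

probs = torch.softmax(logits, dim=1)    # normalize over classes (dim 1)
predicted = torch.argmax(probs, dim=1)  # one class index per sample
print(predicted)  # tensor([1, 0])

# softmax is monotonic within each row, so argmax of the raw logits agrees
assert torch.equal(predicted, logits.argmax(dim=1))
```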

Hello. How can I manipulate the `tensor` using the `indices`? For example, what should I do if I want to change the values of the `tensor` elements that correspond to the `indices`?
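One way to do this, sketched for a 2-D tensor with the max taken along dim 1 (the tensor and the replacement value `-1.0` are illustrative). Advanced indexing pairs each row number with that row's column index. `scatter_` does the same write in place along a dimension and needs the index to have the same number of dimensions as the tensor, hence the `unsqueeze(1)`.

```python
import torch

t = torch.tensor([[1.0, 5.0, 2.0],
                  [4.0, 0.0, 6.0]])
_, idx = t.max(dim=1)  # tensor([1, 2]): column of each row's max

# Option 1: advanced indexing, pairing each row number with its column index
a = t.clone()
rows = torch.arange(a.shape[0])
a[rows, idx] = -1.0

# Option 2: scatter_ writes a scalar at the given indices along dim 1
b = t.clone()
b.scatter_(1, idx.unsqueeze(1), -1.0)

print(torch.equal(a, b))  # True
```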

What about calling argmax on the tensor itself, like:

```
import torch
x = torch.randn(3)
x.argmax(-1)
```
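The method form behaves like the function form. This sketch on a small fixed tensor also shows `keepdim=True`, which keeps the reduced dimension. Calling `argmax()` with no `dim` returns an index into the flattened tensor, and `divmod` converts that index back to (row, column).

```python
import torch

x = torch.tensor([[3.0, 7.0, 1.0],
                  [9.0, 2.0, 4.0]])

# The method form matches the function form
print(x.argmax(-1))             # tensor([1, 0])
print(torch.argmax(x, dim=-1))  # tensor([1, 0])

# keepdim=True keeps the reduced dimension (handy for gather/scatter)
print(x.argmax(dim=-1, keepdim=True).shape)  # torch.Size([2, 1])

# With no dim, argmax indexes into the flattened tensor: 9.0 is element 3
flat = x.argmax()
row, col = divmod(flat.item(), x.shape[1])
print(row, col)  # 1 0
```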

Meta:

- Where does one find the docs for this? Googling takes you to the `torch.argmax` function and not the one for tensors… Answer: it seems to be here: https://pytorch.org/docs/stable/tensors.html

probably useful: