# Use argmax to index a tensor

How can I use the argmax values to index a tensor?

So, for example, I have two tensors `x` and `y` of the same shape, and I have `argmax = x.max(-1)` of one of them. Then I want to get the values of `y` at those positions, i.e. `y[argmax]`.

How can I do that?

``````
_, argmax = x.max(0)
``````

since `max` with a dimension argument returns a `(values, indices)` tuple.
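To make the tuple point concrete, here is a small sketch on a hand-written 2x3 tensor (the values are made up for illustration). Reducing over dim 0 returns one maximum and one row index per column, and `x.argmax(0)` gives the same indices directly:

```python
import torch

# Toy input chosen for illustration only
x = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])

# With a dim argument, max returns a (values, indices) pair
values, argmax = x.max(0)   # reduce over dim 0, i.e. one result per column
print(values)               # tensor([7., 5., 3.])
print(argmax)               # tensor([1, 0, 1])

# argmax alone yields the same indices without the values
assert torch.equal(argmax, x.argmax(0))
```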

Then you can use it as you did:

``````
y[argmax]
``````

``````
import torch

x = torch.rand(3, 4)
y = torch.rand(3, 4)
argmax = x.max(0)
y[argmax]
``````

Does not work for me.

You probably want to use `torch.arange(4)` as the second index, i.e. `y[argmax, torch.arange(4)]`, or use `gather`.

Best regards

Thomas
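Both of Thomas's suggestions can be sketched for the `x.max(0)` case above (random data, seed chosen arbitrarily). `y[argmax]` alone treats each index as a whole row, while pairing it with `torch.arange(4)` or using `gather` picks exactly one value of `y` per column:

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

_, argmax = x.max(0)           # shape (4,): winning row for each column

# y[argmax] alone selects whole rows, giving a (4, 4) tensor, not 4 values
print(y[argmax].shape)         # torch.Size([4, 4])

# Option 1: pair each row index with its column index via arange
picked = y[argmax, torch.arange(4)]

# Option 2: gather along dim 0; the index must have as many dims as y
gathered = y.gather(0, argmax.unsqueeze(0)).squeeze(0)

assert torch.equal(picked, gathered)
print(picked)                  # one value of y per column, shape (4,)
```

`gather` needs the extra `unsqueeze`/`squeeze` but generalizes to any dimension without building an `arange` for the other axis.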


Going with Tom's idea and tweaking the earlier code a bit:

``````
import torch

torch.manual_seed(4)
x = torch.rand(3, 4)
y = torch.rand(3, 4)
print('y', y)
_, argmax = x.max(-1)             # column of the max in each row, shape (3,)
print('argmax', argmax)
y[torch.arange(3), argmax] = 3    # pair each row index with its argmax column
print('y', y)
``````

Result, as required:

``````
y
0.9263  0.4735  0.5949  0.7956
0.7635  0.2137  0.3066  0.0386
0.5220  0.3207  0.6074  0.5233
[torch.FloatTensor of size 3x4]

argmax
0
2
1
[torch.LongTensor of size 3]

y
3.0000  0.4735  0.5949  0.7956
0.7635  0.2137  3.0000  0.0386
0.5220  3.0000  0.6074  0.5233
[torch.FloatTensor of size 3x4]
``````
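The snippet above writes into `y`, while the original question asked to read the values. A sketch of both directions for the `x.max(-1)` case, assuming the same shapes: the same `(rows, argmax)` index pair works on the right-hand side, `gather` is an equivalent read, and `scatter_` is an alternative to the indexed assignment:

```python
import torch

torch.manual_seed(4)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

_, argmax = x.max(-1)                  # shape (3,): winning column per row
rows = torch.arange(x.size(0))         # tensor([0, 1, 2])

# Read: one value of y per row, taken at x's argmax column
values = y[rows, argmax]               # shape (3,)

# Equivalent read with gather along the last dim
values_g = y.gather(-1, argmax.unsqueeze(-1)).squeeze(-1)
assert torch.equal(values, values_g)

# Write: scatter_ places 3.0 at the same positions as the indexed assignment
y2 = y.clone()
y2.scatter_(-1, argmax.unsqueeze(-1), 3.0)
y3 = y.clone()
y3[rows, argmax] = 3.0
assert torch.equal(y2, y3)
```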

This is for assigning `1` to the channel that has the max value across all channels.
``````
def score_max(x, dim, score):