How can I use the argmax values to index a tensor?

For example, I have two tensors `x` and `y` of the same shape, and `argmax = x.min(-1)` computed from one of them. I then want to get the values of `y` at those positions, i.e. `y[argmax]`.

How can I do that?


Your argmax should rather be:

```
argmax = x.max(0)[1]
```

since `max` returns a tuple `(values, indices)`.
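A quick sketch of what that tuple holds, assuming a recent PyTorch where `torch.argmax` is also available: the second element of `x.max(0)` contains the same indices as `x.argmax(0)`.

```python
import torch

x = torch.rand(3, 4)

# with a dim argument, max returns a (values, indices) pair
values, indices = x.max(0)

# the indices are the row positions of the max in each column,
# the same as what x.argmax(0) returns
argmax = x.argmax(0)
print(values.shape, indices.shape)
```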

Then you can use it as you did:

```
y[argmax]
```

```
import torch

x = torch.rand(3,4)
y = torch.rand(3,4)
argmax = x.max(0)[1]
y[argmax]
```

Does not work for me.
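The reason, at least with current PyTorch indexing semantics (older versions may have behaved differently): `argmax` has shape `(4,)` and holds row numbers, so `y[argmax]` indexes only the first dimension. It picks whole rows of `y` and returns a `4x4` tensor instead of one value per column. A small check:

```python
import torch

x = torch.rand(3, 4)
y = torch.rand(3, 4)
argmax = x.max(0)[1]   # shape (4,), row indices in 0..2

rows = y[argmax]       # selects one full row of y per entry of argmax
print(rows.shape)
```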

You probably want to use `arange(4)` as a second index, or use `gather`.
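Both suggestions, sketched for the `x.max(0)` case above (argmax over rows, one index per column). The column index for the first option is `torch.arange(4)`; for `gather`, the index must have the same number of dims as `y`, hence the `unsqueeze`/`squeeze` pair:

```python
import torch

x = torch.rand(3, 4)
y = torch.rand(3, 4)
argmax = x.argmax(0)                  # row index of the max in each column, shape (4,)

# Option 1: pair each row index with its column index
picked = y[argmax, torch.arange(4)]   # y[argmax[j], j] for each column j, shape (4,)

# Option 2: gather along dim 0 with an index of shape (1, 4), then drop that dim
picked_gather = y.gather(0, argmax.unsqueeze(0)).squeeze(0)

print(torch.equal(picked, picked_gather))
```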

Best regards

Thomas


Going with Tom’s idea, and tweaking the earlier code a bit:

```
import torch
import numpy as np
torch.manual_seed(4)
x = torch.rand(3,4)
y = torch.rand(3,4)
print('y', y)
_, argmax = x.max(-1)          # column index of the max in each row of x
print('argmax', argmax)
y[np.arange(3), argmax] = 3    # sets y[i, argmax[i]] for each row i
print('y', y)
```

Result, as required:

```
y
0.9263 0.4735 0.5949 0.7956
0.7635 0.2137 0.3066 0.0386
0.5220 0.3207 0.6074 0.5233
[torch.FloatTensor of size 3x4]
argmax
0
2
1
[torch.LongTensor of size 3]
y
3.0000 0.4735 0.5949 0.7956
0.7635 0.2137 3.0000 0.0386
0.5220 3.0000 0.6074 0.5233
[torch.FloatTensor of size 3x4]
```
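The snippet above writes into `y`. To read the selected values instead, which is what the original question asked for, the same index pair works on the right-hand side, here with `torch.arange` as the row index; `gather` along the last dim gives the same result:

```python
import torch

torch.manual_seed(4)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

argmax = x.argmax(-1)                  # column of the max in each row of x, shape (3,)
vals = y[torch.arange(3), argmax]      # y[i, argmax[i]] for each row i

# same values via gather: the index needs shape (3, 1), then squeeze back to (3,)
vals_gather = y.gather(-1, argmax.unsqueeze(-1)).squeeze(-1)
```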


This might help you: https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/blob/master/util/MaxCoord.py

It assigns `1` to the channel that has the max value across all channels.
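An independent sketch of that behaviour (a `1` in the channel holding the per-position maximum, `0` elsewhere), not taken from the linked file. It assumes an input laid out as `(N, C, H, W)` and uses `scatter_`, which may differ from what `MaxCoord.py` actually does:

```python
import torch

x = torch.rand(2, 5, 3, 3)                         # assumed layout: (N, C, H, W)

idx = x.argmax(dim=1, keepdim=True)                # (N, 1, H, W): winning channel per position
mask = torch.zeros_like(x).scatter_(1, idx, 1.0)   # 1.0 in the winning channel, 0.0 elsewhere
```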

Using `torch.gather` could work. However, it first generates an index tensor with the same size as the original tensor. Does anyone have a better way?

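```
import torch

def score_max(x, dim, score):
    # index shape for gather: 1 everywhere except x.size(dim) along `dim`
    _tmp = [1] * len(x.size())
    _tmp[dim] = x.size(dim)
    # position of the max of `score` along `dim`, repeated to full size along `dim`
    idx = score.max(dim)[1].unsqueeze(dim).repeat(tuple(_tmp))
    # every slice along `dim` is identical, so keep only the first one
    return torch.gather(x, dim, idx).select(dim, 0)
```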