Use Argmax to index tensor

How can I use the argmax values to index a tensor?

So, for example, I have two tensors x and y of the same shape, and I have argmax = x.max(-1) of one of them. Then I want to get the values of y at those positions, i.e. y[argmax]?

How can I do that ?

Your argmax should rather be:

argmax = x.max(0)[1]

since max(dim) returns a tuple of (values, indices).
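A small sketch of the tuple in question, using a fixed tensor (values chosen arbitrarily) so the result is easy to check by eye. In recent PyTorch versions the tuple is a named tuple, so the indices can also be read as `.indices`, or obtained directly with `argmax`:

```python
import torch

x = torch.tensor([[0.1, 0.9, 0.3, 0.4],
                  [0.8, 0.2, 0.7, 0.0],
                  [0.5, 0.6, 0.2, 0.9]])

# max along dim 0 returns (values, indices); indices is the argmax per column
values, indices = x.max(0)

# Equivalent ways to get only the indices
assert torch.equal(indices, x.max(0)[1])
assert torch.equal(indices, x.max(0).indices)
assert torch.equal(indices, x.argmax(0))

# indices holds one row index per column: [1, 0, 1, 2]
```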

Then you can use it as you did:

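y[argmax]

import torch
x = torch.rand(3,4)
y = torch.rand(3,4)
argmax = x.max(0)[1]
y[argmax]

Does not work for me: argmax has shape (4,), so y[argmax] selects four whole rows of y instead of one element per column.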

You probably want to use torch.arange(4) as a second index, i.e. y[argmax, torch.arange(4)], or use gather.

Best regards

Thomas
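Both of Tom's suggestions, sketched for the dim-0 case above with small hand-written tensors (the values are arbitrary, chosen so the picked elements are easy to verify):

```python
import torch

x = torch.tensor([[0.1, 0.9, 0.3, 0.4],
                  [0.8, 0.2, 0.7, 0.0],
                  [0.5, 0.6, 0.2, 0.9]])
y = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.],
                  [9., 10., 11., 12.]])

argmax = x.max(0)[1]           # row index of the max in each column: [1, 0, 1, 2]

# y[argmax] alone picks whole rows, giving a 4x4 result
rows_only = y[argmax]

# Option 1: pair each row index with its column index
cols = torch.arange(x.size(1))
picked = y[argmax, cols]       # [5., 2., 7., 12.]

# Option 2: gather along dim 0 with a (1, 4) index, then drop that dim
gathered = torch.gather(y, 0, argmax.unsqueeze(0)).squeeze(0)

assert torch.equal(picked, gathered)
```

The arange version reads naturally for 2-D tensors; gather generalizes more easily to other dims and ranks.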


Going with Tom's idea, and tweaking the earlier code a bit:

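import torch
import numpy as np

torch.manual_seed(4)
x = torch.rand(3,4)
y = torch.rand(3,4)
print('y', y)
# Column index of the max of x in each row (dim -1)
_, argmax = x.max(-1)
print('argmax', argmax)
# Pair row i with column argmax[i] and overwrite those elements of y with 3
y[np.arange(3), argmax] = 3
print('y', y)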

Result, as required: in each row of y, the element at x's argmax position is now 3:

y 
 0.9263  0.4735  0.5949  0.7956
 0.7635  0.2137  0.3066  0.0386
 0.5220  0.3207  0.6074  0.5233
[torch.FloatTensor of size 3x4]

argmax 
 0
 2
 1
[torch.LongTensor of size 3]

y 
 3.0000  0.4735  0.5949  0.7956
 0.7635  0.2137  3.0000  0.0386
 0.5220  3.0000  0.6074  0.5233
[torch.FloatTensor of size 3x4]
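To read the values instead of overwriting them, the same index pair can be used on the right-hand side. A sketch of three equivalent forms along dim -1; note that `torch.take_along_dim` assumes PyTorch 1.9 or newer:

```python
import torch

torch.manual_seed(0)
x = torch.rand(3, 4)
y = torch.rand(3, 4)

# Column index of the max of x in each row
argmax = x.argmax(-1)                       # shape (3,)

# Advanced indexing: row i paired with column argmax[i]
rows = torch.arange(x.size(0))
vals_index = y[rows, argmax]                # shape (3,)

# gather needs an index with a trailing size-1 dim
vals_gather = y.gather(-1, argmax.unsqueeze(-1)).squeeze(-1)

# take_along_dim accepts the same keepdim-shaped index (PyTorch 1.9+)
vals_take = torch.take_along_dim(y, argmax.unsqueeze(-1), dim=-1).squeeze(-1)

assert torch.equal(vals_index, vals_gather)
assert torch.equal(vals_index, vals_take)
```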

This might help: https://github.com/Zhaoyi-Yan/Shift-Net_pytorch/blob/master/util/MaxCoord.py
It sets to 1 the channel that holds the max value across all channels.
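The linked file is not reproduced here. As a hedged sketch of the described behavior only (not that file's implementation), assuming a hypothetical `(N, C, H, W)` feature map, `scatter_` can write 1 into the argmax channel at each location and leave 0 elsewhere:

```python
import torch

feat = torch.rand(2, 5, 4, 4)              # hypothetical (N, C, H, W) feature map

# Channel index of the max at every spatial location, keeping the channel dim
idx = feat.argmax(dim=1, keepdim=True)     # shape (2, 1, 4, 4)

# 1 at the argmax channel, 0 in every other channel
mask = torch.zeros_like(feat).scatter_(1, idx, 1.0)

# Exactly one active channel per location, and it is the argmax channel
assert mask.sum(dim=1).eq(1).all()
assert torch.equal(mask.argmax(dim=1, keepdim=True), idx)
```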

Using torch.gather works, but it first builds an index tensor the same size as the original tensor (via repeat). Does anyone have a better way?

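import torch

def score_max(x, dim, score):
    # Repeat factors: 1 everywhere except along `dim`, repeated x.size(dim) times
    _tmp = [1] * len(x.size())
    _tmp[dim] = x.size(dim)
    # Argmax of `score` along `dim`, with `dim` restored as a size-1 axis
    idx = score.max(dim)[1].unsqueeze(dim)
    # Gather x at those indices, then keep a single slice along `dim`
    return torch.gather(x, dim, idx.repeat(tuple(_tmp))).select(dim, 0)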
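One possible answer, sketched under the assumption that the goal is only the values of x at the argmax of score along dim: gather only requires index.size(d) <= x.size(d) for every d != dim, so a size-1 index along dim (from keepdim=True) can be passed directly and the repeat is not needed. The helper name below is hypothetical:

```python
import torch

def score_max_keepdim(x, dim, score):
    # Hypothetical helper: values of x at the argmax of score along dim.
    # The keepdim index has size 1 along dim, which gather accepts as-is,
    # so no full-size index tensor is allocated.
    idx = score.argmax(dim, keepdim=True)
    return torch.gather(x, dim, idx).squeeze(dim)

x = torch.rand(3, 4, 5)
score = torch.rand(3, 4, 5)
out = score_max_keepdim(x, 1, score)       # shape (3, 5)

# Cross-check with explicit advanced indexing (broadcast row/col indices)
i = torch.arange(3).view(3, 1)
k = torch.arange(5).view(1, 5)
assert torch.equal(out, x[i, score.argmax(1), k])
```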