Use argmax as index for a new array


At first I thought that using the indices returned by argmax on a tensor to index that same tensor would give the same result as the values returned by the max method. I was wrong.
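Here is a minimal sketch of why that first idea fails, using the same (5, 10) tensor as the 2-D case below. An index tensor used on its own selects whole slices along the first dimension, so `b[b.argmax(dim=0)]` returns one full row per column instead of one value per column. Pairing it with a column index picks single elements instead.

```python
import torch

b = torch.arange(50).view(5, 10)
cx, c = b.max(dim=0)        # cx: max value of each column, c: row index of that max
idx = b.argmax(dim=0)       # same indices as c (arange has no ties)

# A single index tensor selects whole rows along dim 0,
# so the result has shape (10, 10) rather than (10,).
naive = b[idx]
print(naive.shape)  # torch.Size([10, 10])

# Pairing each row index with its column index picks one element per column.
picked = b[idx, torch.arange(b.size(1))]
print(torch.equal(picked, cx))  # True
```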

I want to get the values at those indices, both from the tensor itself and from another tensor, so I really need to be able to use the argmax output as indices.
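One loop-free way to apply argmax indices to any tensor of the same shape is `torch.gather`, sketched below. The helper name `take_at_indices` is hypothetical, and it assumes `idx` was produced by `max`/`argmax` along `dim` without `keepdim=True`. The random tensors and the batch size of 2 are only for illustration.

```python
import torch

def take_at_indices(src: torch.Tensor, idx: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: pick values of src at positions idx along dim.

    Assumes idx comes from max/argmax along `dim` (without keepdim)
    on a tensor with the same shape as src.
    """
    # gather needs the index to have as many dims as src,
    # so restore the reduced dimension with size 1 ...
    picked = torch.gather(src, dim, idx.unsqueeze(dim))
    # ... then drop it again so the result matches max's output shape.
    return picked.squeeze(dim)

torch.manual_seed(0)
a = torch.randn(2, 5, 5, 10)
b = torch.randn(2, 5, 5, 10)
cx, c = b.max(dim=1)

# Values of b at its own argmax reproduce cx ...
assert torch.equal(take_at_indices(b, c, dim=1), cx)
# ... and the same indices can be applied to a different tensor a.
a_at_b_max = take_at_indices(a, c, dim=1)
print(a_at_b_max.shape)  # torch.Size([2, 5, 10])
```

Because the indices are reused unchanged, the same call works for `a` and `b`. The `unsqueeze`/`squeeze` pair is there because `gather` expects the index to have the same number of dimensions as the source.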

So far it works for a matrix and for a 3-D tensor:

```python
import torch

# 2-D case (works)
a = torch.arange(50).view(5,10)
b = torch.arange(50).view(5,10)
cx,c = b.max(dim=0)

assert (b[c,torch.arange(10)] == cx).all()

# 3-D case (works)
a = torch.arange(50).view(1,5,10)
b = torch.arange(50).view(1,5,10)
cx,c = b.max(dim=1)

assert (b[:,c,torch.arange(10)] == cx).all()

# 4-D case (this assert raises the TypeError below)
a = torch.arange(250).view(1,5,5,10)
b = torch.arange(250).view(1,5,5,10)
cx,c = b.max(dim=1)

assert (b[:,c,[torch.arange(5),torch.arange(10)] ] == cx).all()
```

The last assert raises the following error:

```
TypeError                                 Traceback (most recent call last)
<ipython-input-118-390e76989bea> in <module>
----> 1 b[:,c,[torch.arange(5),torch.arange(10)] ] == cx

TypeError: only integer tensors of a single element can be converted to an index
```

I am a bit stuck, and I don’t want to use a Python loop (for speed reasons).
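The error appears to come from the list `[torch.arange(5), torch.arange(10)]`. It is passed as the index for a single dimension, and PyTorch cannot turn two tensors of different lengths into one index. Below is a loop-free sketch that keeps the advanced-indexing style of the 2-D and 3-D cases: give every dimension its own index tensor, shaped so that all of them broadcast to the shape of `c`, here (1, 5, 10). The names `n_idx`, `h_idx`, and `w_idx` are illustrative.

```python
import torch

b = torch.arange(250).view(1, 5, 5, 10)
cx, c = b.max(dim=1)                    # cx, c: shape (1, 5, 10)

N, _, H, W = b.shape
# One index tensor per dimension of b; they broadcast to c's shape (1, 5, 10).
n_idx = torch.arange(N).view(N, 1, 1)   # index along dim 0 (batch)
h_idx = torch.arange(H).view(1, H, 1)   # index along dim 2
w_idx = torch.arange(W).view(1, 1, W)   # index along dim 3

picked = b[n_idx, c, h_idx, w_idx]      # shape (1, 5, 10), same as cx
assert (picked == cx).all()
```

An explicit batch index is used instead of `:` so the result has the same shape as `cx`. With `:` in the first position, the slice is combined with every batch entry, so the result gets an extra dimension ((1, 1, 5, 10) here), and for a batch size above 1 it would mix indices from different batch entries.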

Thanks in advance :slight_smile:



It works, thanks :slight_smile: