Hi,
At first I thought that using the output of the argmax method as indices into the same tensor would give the same result as the max method. I was wrong.
I want to get the values at those indices, both for this tensor and for another tensor, so I really need to be able to use the argmax result as indices.
Right now I have managed to do it for a matrix and for a 3-D tensor, but not for a 4-D tensor:
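To illustrate the difference described above, here is a small sketch (the tensor is just the 5×10 example from this post). Indexing with the argmax result alone selects whole slices along the first dimension. Pairing each index with its column position picks one element per column, which matches the output of `max`:

```python
import torch

b = torch.arange(50).view(5, 10)

values, indices = b.max(dim=0)      # both have shape (10,)
arg = b.argmax(dim=0)               # same indices as `indices` here (no ties)

# Using the indices alone indexes dim 0 only: this selects 10 whole rows,
# giving a (10, 10) tensor rather than the 10 maxima.
rows = b[arg]
print(rows.shape)                   # torch.Size([10, 10])

# Pairing each index with its column picks one element per column,
# which reproduces the values returned by max.
picked = b[arg, torch.arange(b.size(1))]
print(torch.equal(picked, values))  # True
```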
```python
import torch

# 2-D case: works
a = torch.arange(50).view(5,10)
b = torch.arange(50).view(5,10)
cx,c = b.max(dim=0)
assert (b[c,torch.arange(10)] == cx).all()

# 3-D case with a leading dimension of size 1: works
a = torch.arange(50).view(1,5,10)
b = torch.arange(50).view(1,5,10)
cx,c = b.max(dim=1)
assert (b[:,c,torch.arange(10)] == cx).all()

# 4-D case: fails
a = torch.arange(250).view(1,5,5,10)
b = torch.arange(250).view(1,5,5,10)
cx,c = b.max(dim=1)
assert (b[:,c,[torch.arange(5),torch.arange(10)] ] == cx).all()
```
I get the following error on the last assert:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-118-390e76989bea> in <module>
----> 1 b[:,c,[torch.arange(5),torch.arange(10)] ] == cx
TypeError: only integer tensors of a single element can be converted to an index
```
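My understanding of the error is that a Python list inside the index is treated as a list of single integer indices, so PyTorch tries to convert each multi-element tensor in it to one integer and fails. A minimal sketch of a loop-free alternative, assuming the goal is `out[n, h, w] = b[n, c[n, h, w], h, w]`, passes one index tensor per dimension and shapes them so they broadcast to `c.shape` (the names `i`, `k`, `l` are just illustrative):

```python
import torch

b = torch.arange(250).view(1, 5, 5, 10)
cx, c = b.max(dim=1)                 # cx, c: shape (1, 5, 10)

B, _, H, W = b.shape
# One index tensor per dimension, shaped so that all of them broadcast to c.shape:
#   batch index -> (B, 1, 1), row index -> (H, 1), column index -> (W,)
i = torch.arange(B).view(B, 1, 1)
k = torch.arange(H).view(H, 1)
l = torch.arange(W)

# out[n, h, w] = b[n, c[n, h, w], h, w]
out = b[i, c, k, l]
print(out.shape)                     # torch.Size([1, 5, 10])
print(torch.equal(out, cx))          # True
```

Indexing the batch dimension with `i` instead of `:` also avoids a subtle issue in the 3-D case above. There, `b[:, c, ...]` pairs every batch entry with every row of `c`, so it only matches `cx` because the batch size is 1.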
I am a bit stuck, and I don’t want to use a Python loop (for speed reasons).
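Another loop-free option is `torch.gather`, which picks values along one dimension using an index tensor with the same number of dimensions as the input. The sketch below is written under a few assumptions: `max` is called with `keepdim=True` so the indices keep that shape, and `a` is a hypothetical second tensor with the same shape as `b` (here `b * 2`, just so the two results differ):

```python
import torch

b = torch.arange(250).view(1, 5, 5, 10)
a = b * 2                                  # hypothetical second tensor, same shape as b
dim = 1

# keepdim=True leaves the reduced dimension in place with size 1, so the
# index tensor has as many dimensions as b, which gather requires.
cx, c = b.max(dim=dim, keepdim=True)       # cx, c: shape (1, 1, 5, 10)

# For every position, gather takes the element at index c along `dim`.
b_vals = b.gather(dim, c).squeeze(dim)     # shape (1, 5, 10), equal to the max values
a_vals = a.gather(dim, c).squeeze(dim)     # the same indices applied to another tensor

print(torch.equal(b_vals, cx.squeeze(dim)))  # True
```

This works for any `dim` and any number of dimensions without building one index tensor per axis. Recent PyTorch versions also provide `torch.take_along_dim`, which can be used in the same way.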
Thanks in advance