Get ALL indices of the maximum of a tensor

Hi everyone,
I am looking for a way to get ALL indices of the maximum for each row in a 2-D tensor.
Let us define the following tensor:

a = torch.tensor([[5,8,9,6,9], [5,4,3,2,1]])

I already tried torch.max and torch.argmax, but I didn’t get the desired output:

torch.max(a, dim=1) # returns values tensor([9,5]) and indices tensor([4,0])
max_val, idx = torch.max(a, dim=1, keepdim=True) # max_val is tensor([[9],[5]]) and idx is tensor([[4],[0]])

I am looking for a function that returns ALL indices of the max value in each row. So I expect the following tensor, since the 9 occurs twice in the first row and the 5 only once in the second row:

torch.tensor([[2,4],[0]])

Does anyone have an idea?
Thanks in advance!

Hi Roxor!

Try this:

>>> torch.__version__
'1.7.1'
>>> a = torch.tensor([[5,8,9,6,9], [5,4,3,2,1]])
>>> torch.nonzero ((a == a.max (dim = 1, keepdim = True)[0]))
tensor([[0, 2],
        [0, 4],
        [1, 0]])
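To make the one-liner easier to follow, here is a minimal sketch that unpacks it into separate steps. It assumes a PyTorch version where the result of `max(dim=...)` exposes `.values` and where `torch.nonzero` accepts `as_tuple=True`; both are used here for readability and are not part of the original answer.

```python
import torch

a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])

# Per-row maximum; keepdim=True keeps shape (2, 1) so it broadcasts against a.
row_max = a.max(dim=1, keepdim=True).values   # tensor([[9], [5]])

# Boolean mask of shape (2, 5): True wherever an entry equals its row's maximum.
mask = a == row_max

# as_tuple=True returns the row and column indices as two separate 1-D tensors
# instead of one (N, 2) tensor of (row, column) pairs.
rows, cols = torch.nonzero(mask, as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([2, 4, 0])
```

The comparison does the real work: every tied maximum in a row produces a `True` in the mask, so `nonzero` reports all of them rather than a single index.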

Note that pytorch does not support "ragged tensors" (tensors whose rows
differ from one another in length), so you can't get your result in the
[[2,4],[0]] format you asked for. Instead, nonzero() returns one
(row, column) pair for each maximum.
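If a per-row grouping is still needed, one workaround is to leave the tensor world for the outer level and keep a Python list of 1-D index tensors, or to pad the rows to a common length. The sketch below shows three such options; the padding value `-1` is an arbitrary choice made here for illustration, not something from the original answer.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])

# Same mask as in the one-liner: True where an entry equals its row maximum.
mask = a == a.max(dim=1, keepdim=True).values

# Option 1: a plain Python list holding one 1-D index tensor per row.
per_row = [row.nonzero().flatten() for row in mask]
print(per_row)  # [tensor([2, 4]), tensor([0])]

# Option 2: split the column indices from nonzero() by the number of maxima per row.
idx = torch.nonzero(mask)              # shape (num_maxima, 2): (row, col) pairs
counts = mask.sum(dim=1).tolist()      # number of maxima in each row: [2, 1]
per_row_split = list(torch.split(idx[:, 1], counts))

# Option 3: a rectangular tensor, padding the shorter rows with -1 (arbitrary filler).
padded = pad_sequence(per_row, batch_first=True, padding_value=-1)
print(padded)  # tensor([[ 2,  4], [ 0, -1]])
```

Option 2 avoids a Python loop over rows, which may matter for large tensors; option 3 keeps everything in one tensor at the cost of a sentinel value that downstream code has to ignore.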

Best.

K. Frank