# Get ALL indices of the maximum of a tensor

Hi everyone,
I am looking for a way to get ALL indices of the maximum for each row in a 2-D tensor.
Let us define the following tensor:

``````
a = torch.tensor([[5,8,9,6,9], [5,4,3,2,1]])
``````

I already tried torch.max and torch.argmax, but I didn’t get the desired output:

``````
torch.argmax(a, dim=1) # returns tensor([4,0])
max_val, idx = torch.max(a, dim=1, keepdim=True) # max_val is tensor([[9],[5]]) and idx is tensor([[4],[0]])
``````

I am looking for a function which returns ALL indices of the max value for each row. So I expect the following tensor, since the 9 occurs twice in the first row and the 5 only once in the second row.

``````
torch.tensor([[2,4],[0]])
``````

Does anyone have an idea?

Hi Roxor!

Try this:

``````
>>> torch.__version__
'1.7.1'
>>> a = torch.tensor([[5,8,9,6,9], [5,4,3,2,1]])
>>> torch.nonzero ((a == a.max (dim = 1, keepdim = True)[0]))
tensor([[0, 2],
        [0, 4],
        [1, 0]])
``````
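For reference, the one-liner can be unpacked into its individual steps. This is only a sketch of the same computation; the names `row_max`, `mask`, and `indices` are mine and purely illustrative:

```python
import torch

a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])

# Per-row maximum, kept as a column so that it broadcasts against `a`.
row_max = a.max(dim=1, keepdim=True)[0]  # shape (2, 1)

# True wherever an entry equals the maximum of its own row.
mask = a == row_max  # shape (2, 5)

# One [row, column] pair for every True entry of the mask.
indices = torch.nonzero(mask)
print(indices)
```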

Each row of the output is a [row, column] pair locating one maximum. Note that PyTorch does not support “ragged tensors” (tensors whose rows
differ from one another in length), so you can’t get your result in the
format you asked for, `torch.tensor([[2,4],[0]])`.
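If you really do need the column indices grouped by row, one possible workaround is a plain Python list holding one 1-D tensor per row. A minimal sketch, assuming a 2-D input; `per_row` is just an illustrative name:

```python
import torch

a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])

# Boolean mask marking every entry equal to its row's maximum.
mask = a == a.max(dim=1, keepdim=True)[0]

# One 1-D tensor of column indices per row; the lengths may differ
# between rows, which is why this is a list and not a single tensor.
per_row = [torch.nonzero(row_mask).flatten() for row_mask in mask]
print(per_row)
```

The list comprehension loops over rows in Python, so for very large tensors the flat result from `torch.nonzero` is likely the better starting point.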

Best.

K. Frank