Hi everyone,
I am looking for a way to get ALL indices of the maximum in each row of a 2-D tensor.
Let us define the following tensor:
a = torch.tensor([[5,8,9,6,9], [5,4,3,2,1]])
I already tried torch.max and torch.argmax, but they return only a single index per row, so I didn’t get the desired output:
torch.max(a, dim=1) # returns values tensor([9, 5]) and indices tensor([4, 0])
max_val, idx = torch.max(a, dim=1, keepdim=True) # max_val is tensor([[9],[5]]) and idx is tensor([[4],[0]])
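The keepdim=True variant is useful here: because max_val keeps shape (2, 1), it broadcasts against a (shape (2, 5)), so comparing the two marks every position that equals its row's maximum, ties included. A minimal sketch, assuming the same tensor a as above:

```python
import torch

a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])

# keepdim=True keeps max_val as shape (2, 1) so it broadcasts against a (2, 5)
max_val, _ = torch.max(a, dim=1, keepdim=True)

# True wherever an entry equals its row's maximum, so all tied maxima are kept
mask = a == max_val
# mask is True at columns 2 and 4 in row 0, and at column 0 in row 1
print(mask)
```

Unlike the idx returned by torch.max, this mask does not depend on which of the tied positions the tie-breaking picks.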
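I am looking for a function that returns ALL indices of the max value in each row. Since 9 occurs twice in the first row and 5 only once in the second row, I expect the following output (the rows have different lengths, so it can't be a regular tensor; torch.tensor([[2,4],[0]]) raises an error, but a list of tensors would do):
[torch.tensor([2, 4]), torch.tensor([0])]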
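One possible sketch (all_argmax_per_row is a hypothetical helper name, not a PyTorch function): build the boolean mask of row maxima, take the column indices of its True entries with nonzero, and split that flat index tensor by the number of maxima in each row. This assumes a is a 2-D tensor without NaNs (NaN never compares equal, so a NaN row would yield no indices).

```python
from typing import List

import torch


def all_argmax_per_row(a: torch.Tensor) -> List[torch.Tensor]:
    """For each row of a 2-D tensor, return a 1-D tensor of every column holding the row max."""
    max_val = a.max(dim=1, keepdim=True).values
    mask = a == max_val
    # nonzero(as_tuple=True) gives (row_indices, col_indices) in row-major order,
    # so the column indices come grouped row by row
    _, cols = mask.nonzero(as_tuple=True)
    # number of maxima in each row, used to cut the flat column indices apart
    counts = mask.sum(dim=1)
    return list(torch.split(cols, counts.tolist()))


a = torch.tensor([[5, 8, 9, 6, 9], [5, 4, 3, 2, 1]])
result = all_argmax_per_row(a)
print(result)  # [tensor([2, 4]), tensor([0])]
```

A plain list comprehension such as [row.nonzero().flatten() for row in mask] gives the same result; the split version just avoids one nonzero call per row.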
Does anyone have an idea?
Thanks in advance!