Get 2d indices for max and min

Hi
Let's say I have a tensor of shape [batch_size, 7, 7].
For each sample in the batch, I want the 2D indices of the two largest values, plus the 2D indices of the two smallest values.

Example: let's say we have a tensor of shape [2, 3, 3]:

tensor = [[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
          [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]]

The two maximum values are at [(2, 2), (2, 1)] and [(2, 2), (2, 1)] for the two samples;
the two minimum values are at [(0, 0), (0, 1)] and [(0, 0), (0, 1)] for the two samples as well.

I have tried something like:

torch.nonzero(a == a.max(dim=1, keepdim=True)[0])

but it reduces along a single dimension (rows or columns), not over each sample's whole 2D slice.
Thanks!
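For reference, the attempt above can be pointed at the whole 2D slice by reducing over both spatial dimensions at once. A minimal sketch, assuming PyTorch 1.7 or newer (for torch.amax with a tuple of dims) and using the example tensor under the name a to match the attempt. It only finds the single largest value per sample (plus any ties), not the top two, which is why torch.topk is the better fit here.

```python
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
                  [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]])

# reduce over dims 1 and 2 together, so each sample gets one global max
per_sample_max = a.amax(dim=(1, 2), keepdim=True)  # shape [2, 1, 1]
# one (batch, row, col) row for every position equal to its sample's max
found = torch.nonzero(a == per_sample_max)
print(found.tolist())
# [[0, 2, 2], [1, 2, 2]]
```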

For the maximum:

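import torch

batch, _, width = tensor.shape  # tensor: a torch.Tensor of shape [batch, H, W]
# flatten each sample and take the flat indices of its 2 largest values
index = torch.topk(tensor.reshape(batch, -1), 2, dim=1)[1][..., None]
a, b = index // width, index % width  # flat index -> (row, col)
res = torch.cat([a, b], dim=2)
print(res)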
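To reuse this on the [batch_size, 7, 7] tensor from the question, the same flatten-then-unflatten idea can be wrapped in a small helper. A minimal sketch: topk_2d is a hypothetical name (not a PyTorch function), and torch.div(..., rounding_mode="floor") assumes PyTorch 1.8 or newer. The last lines check that indexing back with the returned (row, col) pairs recovers exactly the values torch.topk reports.

```python
import torch

def topk_2d(t: torch.Tensor, k: int, largest: bool = True) -> torch.Tensor:
    """Hypothetical helper: (row, col) of the k largest (or smallest) entries
    of each [H, W] slice of a [batch, H, W] tensor, as a [batch, k, 2] tensor."""
    batch, _, width = t.shape
    flat_idx = torch.topk(t.reshape(batch, -1), k, dim=1, largest=largest).indices
    rows = torch.div(flat_idx, width, rounding_mode="floor")  # flat index // width
    cols = flat_idx % width
    return torch.stack([rows, cols], dim=2)

torch.manual_seed(0)
x = torch.randn(4, 7, 7)                 # a [batch_size, 7, 7] input
top2 = topk_2d(x, 2)                     # two largest per sample
bottom2 = topk_2d(x, 2, largest=False)   # two smallest per sample

# sanity check: gather values at the returned (row, col) positions
b = torch.arange(x.shape[0])[:, None]    # [4, 1], broadcasts against [4, 2]
vals = x[b, top2[..., 0], top2[..., 1]]
print(torch.equal(vals, torch.topk(x.reshape(4, -1), 2, dim=1).values))  # True
```

Results come back sorted (largest first for the max, smallest first for the min) because torch.topk defaults to sorted=True.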

For the minimum, run the same code on -tensor, or pass largest=False to torch.topk.
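A minimal sketch on the example tensor from the question, comparing the negation trick with torch.topk's largest=False argument, which asks for the smallest values directly without building a negated copy. If the data contained ties, the two routes might order tied entries differently, so the equality check holds for this tie-free case rather than in general.

```python
import torch

tensor = torch.tensor([[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
                       [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]])
batch, _, width = tensor.shape
flat = tensor.reshape(batch, -1)

# option 1: negate, then take the largest
idx_neg = torch.topk(-flat, 2, dim=1).indices
# option 2: ask topk for the smallest values directly
idx_small = torch.topk(flat, 2, dim=1, largest=False).indices

# flat index -> (row, col), shape [batch, 2, 2]
mins = torch.stack([idx_small // width, idx_small % width], dim=2)
print(torch.equal(idx_neg, idx_small))  # True
print(mins.tolist())
# [[[0, 0], [0, 1]], [[0, 0], [0, 1]]]
```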
