Hi
Let's say I have a tensor of shape [batch_size, 7, 7].
For each sample in the batch, I want to get the two largest values and the two smallest values, together with their 2D (row, col) indices.
Example: let's say we have a tensor of shape [2, 3, 3]:
tensor = [[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
          [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]]
The two maximum values are at [(2, 2), (2, 1)] for the first sample and [(2, 2), (2, 1)] for the second,
and the two minimum values are at [(0, 0), (0, 1)] for both samples as well.
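One possible approach, sketched here as a suggestion and not as the canonical way: flatten each 3x3 (or 7x7) matrix into a single row, call torch.topk with largest=True and largest=False, then turn each flat index back into (row, col) with floor division and modulo by the row width. The helper name topk_2d is hypothetical. Note that with tied values, which of the tied positions topk returns is not guaranteed.

```python
import torch

def topk_2d(x: torch.Tensor, k: int = 2):
    """Hypothetical helper: k largest and k smallest entries of each matrix in x.

    x has shape [batch_size, H, W]. Returns (max_vals, max_rc, min_vals, min_rc),
    where *_vals has shape [batch_size, k] and *_rc has shape [batch_size, k, 2]
    holding (row, col) pairs.
    """
    b, h, w = x.shape
    flat = x.reshape(b, h * w)  # one row of H*W values per sample

    # topk returns values sorted (descending for largest=True, ascending otherwise)
    max_vals, max_idx = flat.topk(k, dim=1, largest=True)
    min_vals, min_idx = flat.topk(k, dim=1, largest=False)

    # flat index -> (row, col): row = idx // W, col = idx % W
    max_rc = torch.stack((torch.div(max_idx, w, rounding_mode="floor"), max_idx % w), dim=-1)
    min_rc = torch.stack((torch.div(min_idx, w, rounding_mode="floor"), min_idx % w), dim=-1)
    return max_vals, max_rc, min_vals, min_rc


# The example tensor from the question
tensor = torch.tensor([[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
                       [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]], dtype=torch.float32)
max_vals, max_rc, min_vals, min_rc = topk_2d(tensor, k=2)
print(max_rc.tolist())  # [[[2, 2], [2, 1]], [[2, 2], [2, 1]]]
print(min_rc.tolist())  # [[[0, 0], [0, 1]], [[0, 0], [0, 1]]]

# Same call on a [batch_size, 7, 7] input
batch = torch.randn(4, 7, 7)
b_max_vals, b_max_rc, b_min_vals, b_min_rc = topk_2d(batch, k=2)
```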
I have tried something along the lines of:
torch.nonzero(a == a.max(dim=1, keepdim=True)[0])
but that reduces along a single dimension (rows or columns), not over the whole 2D matrix of each sample.
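A short sketch of why the attempt above behaves that way, assuming a PyTorch version that has torch.amax (1.7 or later): a.max(dim=1) reduces over the row dimension only, so it yields one maximum per column. Reducing over both matrix dimensions with amax(dim=(-2, -1)) gives one maximum per sample, and the nonzero trick then works on the whole matrix. This only locates the single maximum (plus any ties), though; for the two largest values, something like torch.topk on a flattened view would still be needed.

```python
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6], [6, 7, 8]],
                  [[-9, -8, -7], [-6, -5, -4], [-3, -2, -1]]], dtype=torch.float32)

# Reduces over rows only: one maximum per column, shape [2, 1, 3]
per_column_max = a.max(dim=1, keepdim=True)[0]

# Reduces over both matrix dims: one maximum per sample, shape [2, 1, 1]
per_sample_max = a.amax(dim=(-2, -1), keepdim=True)

# Each row of the result is (batch, row, col); tied maxima would all be listed
argmax_positions = torch.nonzero(a == per_sample_max)
print(argmax_positions.tolist())  # [[0, 2, 2], [1, 2, 2]]
```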
Thanks!