Find max row index with some condition

I have an input tensor of shape (N, 2). The first column is the category and the second is the value, something like this:

[ [14,50] , [18,1.5] , [14,250] , [18,2.5] , [10,1] , [14,252] , [18,5.3] ]

Each category may have several rows in the input tensor.
How can I find the index of the row that has the max value in each category? For the input tensor above, I want something like this:

[ [14,5] , [18,6] , [10,4] ]

[ category code, index of the row with the max value ]
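To pin down the expected result, here is a minimal loop-based sketch that produces exactly the output above for the example input. The helper name `argmax_per_category_loop` is hypothetical, and categories are reported in order of first appearance, as in the example.

```python
import torch

# Example input from the question: column 0 = category, column 1 = value
x = torch.tensor([[14, 50], [18, 1.5], [14, 250], [18, 2.5],
                  [10, 1], [14, 252], [18, 5.3]])

def argmax_per_category_loop(x):
    """Return [[category, row index of max value], ...] using a plain Python loop."""
    result = []
    # dict.fromkeys keeps unique categories in order of first appearance
    for cat in dict.fromkeys(x[:, 0].tolist()):
        rows = torch.nonzero(x[:, 0] == cat).squeeze(1)  # row indices in this category
        best = rows[x[rows, 1].argmax()]                 # row with the largest value
        result.append([int(cat), int(best)])
    return result

print(argmax_per_category_loop(x))  # → [[14, 5], [18, 6], [10, 4]]
```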

If possible, I’d like to do this without using a loop.
Thanks!

There is no groupby function in PyTorch, so there is no direct way to do this without a for loop. You can convert the tensor to NumPy and then use pandas groupby.
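A minimal sketch of the pandas route, assuming the tensor lives on the CPU (call `.cpu()` first otherwise). `groupby(...).idxmax()` returns the DataFrame index label of each group's max, and with the default index that label is the original row index.

```python
import pandas as pd
import torch

x = torch.tensor([[14, 50], [18, 1.5], [14, 250], [18, 2.5],
                  [10, 1], [14, 252], [18, 5.3]])

# Tensor -> NumPy -> DataFrame; the default RangeIndex matches the tensor's row index
df = pd.DataFrame(x.numpy(), columns=["category", "value"])

# sort=False keeps categories in order of first appearance;
# idxmax gives the index label (= row index) of the max value in each group
best_rows = df.groupby("category", sort=False)["value"].idxmax()

result = [[int(cat), int(row)] for cat, row in best_rows.items()]
print(result)  # → [[14, 5], [18, 6], [10, 4]]
```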

Yes, in fact I searched for something like groupby in PyTorch and didn’t find it.
Thanks for the reply, Kushaj.
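For reference, newer PyTorch releases (1.12 and later) provide `Tensor.scatter_reduce`, which can express this per-category max without a Python loop. The following is a minimal sketch under that assumption. The helper name `argmax_per_category` is hypothetical. Categories come out sorted (because of `torch.unique`), and ties are broken toward the smallest row index.

```python
import torch

def argmax_per_category(x):
    """Return a (K, 2) tensor of [category, row index of max value], with no Python loop."""
    # Sorted unique categories, plus the group id of every row
    cats, inv = torch.unique(x[:, 0], return_inverse=True)
    vals = x[:, 1]
    k, n = cats.numel(), x.shape[0]

    # Per-category max value in a single scatter
    init = torch.full((k,), float("-inf"), dtype=vals.dtype, device=vals.device)
    group_max = init.scatter_reduce(0, inv, vals, reduce="amax")

    # Rows holding their category's max keep their index; other rows get the sentinel n
    idx = torch.arange(n, device=x.device)
    candidates = torch.where(vals == group_max[inv], idx, torch.full_like(idx, n))

    # Smallest qualifying row index per category
    first = torch.full((k,), n, dtype=idx.dtype, device=x.device)
    first = first.scatter_reduce(0, inv, candidates, reduce="amin")
    return torch.stack([cats.long(), first], dim=1)

x = torch.tensor([[14, 50], [18, 1.5], [14, 250], [18, 2.5],
                  [10, 1], [14, 252], [18, 5.3]])
print(argmax_per_category(x).tolist())  # → [[10, 4], [14, 5], [18, 6]]
```

The sort order differs from the example in the question. If first-appearance order matters, the rows of the result can be reordered afterwards.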
