Groupby-aggregated topk in PyTorch

Hi there,

This is a bit of an odd question, but I’m wondering whether this is possible to do with PyTorch out of the box.
Basically, I’m trying to replicate an operation that I can do quite easily with a pandas groupby.

Let’s say I have a tensor full of people’s ages, where each person belongs to a class that is simply identified by an integer. In the following example, the class identifiers (99 and 55) are in column 0 and the ages are in column 1. Note that the classes are of different sizes.

import torch

ages_by_class = torch.tensor([[99, 24],
                              [99, 13],
                              [55, 33],  # <--- ages not necessarily sorted in any order a priori
                              [55, 43],
                              [55, 36]])

I’m trying to get the indices, a boolean mask, or the values corresponding to the top-k ages within each group. For k = 2, the boolean mask for the example above would be:

# k = 2: both rows of class 99, plus ages 43 and 36 from class 55
mask_top2 = [[1],
             [1],
             [0],
             [1],
             [1]]
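To pin down what "top-k within each group" means, here is a minimal, non-vectorized sketch using only `torch.unique`, `torch.nonzero` and `torch.topk`. The helper name `groupby_topk_mask_loop` is hypothetical, and the Python loop over groups will get slow with many classes, so treat it as a reference to check a faster version against rather than as the answer.

```python
import torch


def groupby_topk_mask_loop(groups: torch.Tensor, values: torch.Tensor, k: int) -> torch.Tensor:
    """Reference (loop-based) version: True for the k largest values in each group."""
    mask = torch.zeros(values.shape, dtype=torch.bool, device=values.device)
    for g in torch.unique(groups):
        idx = torch.nonzero(groups == g).squeeze(1)  # rows that belong to group g
        kk = min(k, idx.numel())                     # a group may have fewer than k rows
        top = torch.topk(values[idx], kk).indices    # positions of the top-kk inside the group
        mask[idx[top]] = True                        # map back to original row positions
    return mask


ages_by_class = torch.tensor([[99, 24], [99, 13], [55, 33], [55, 43], [55, 36]])
mask = groupby_topk_mask_loop(ages_by_class[:, 0], ages_by_class[:, 1], k=2)
print(mask.tolist())  # [True, True, False, True, True]
```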

In pandas, the solution is simply to sort by age in descending order, groupby() the class id column, and then take .head(k), but it’s not clear how I could translate this logic into PyTorch. Any help would be greatly appreciated!
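One possible vectorized translation of that pandas logic is sketched below, assuming a 1-D tensor of class ids and a 1-D tensor of ages of the same length, and a PyTorch version whose `torch.sort` accepts `stable=True`. The helper name `groupby_topk_mask` is hypothetical, not a PyTorch API. Instead of a groupby, it sorts by age (descending), then stable-sorts by class id so each class becomes contiguous while staying age-ordered, and finally computes each row's rank inside its class; keeping rank < k plays the role of `.head(k)`.

```python
import torch


def groupby_topk_mask(groups: torch.Tensor, values: torch.Tensor, k: int) -> torch.Tensor:
    """True for the k largest values within each group id (ties broken arbitrarily)."""
    n = values.numel()
    positions = torch.arange(n, device=values.device)
    # 1) order rows by value, largest first (like sort_values(ascending=False))
    order = torch.argsort(values, descending=True)
    # 2) stable sort by group id: groups become contiguous and keep the
    #    descending-value order inside each group (plays the role of groupby)
    order = order[torch.sort(groups[order], stable=True).indices]
    sorted_groups = groups[order]
    # 3) flag the first row of each contiguous group
    is_start = torch.ones(n, dtype=torch.bool, device=values.device)
    is_start[1:] = sorted_groups[1:] != sorted_groups[:-1]
    # 4) carry each group's start position forward with a running max
    start_pos = torch.where(is_start, positions, torch.zeros_like(positions))
    start_pos = torch.cummax(start_pos, dim=0).values
    # 5) rank within the group = distance from the group start; keep rank < k (like .head(k))
    keep_sorted = (positions - start_pos) < k
    # 6) scatter the flags back to the original row order
    mask = torch.zeros(n, dtype=torch.bool, device=values.device)
    mask[order] = keep_sorted
    return mask


ages_by_class = torch.tensor([[99, 24], [99, 13], [55, 33], [55, 43], [55, 36]])
mask_top2 = groupby_topk_mask(ages_by_class[:, 0], ages_by_class[:, 1], k=2)
print(mask_top2.tolist())  # [True, True, False, True, True]
```

From the mask, `ages_by_class[mask_top2]` gives the selected rows, `mask_top2.nonzero(as_tuple=True)[0]` gives their indices, and `mask_top2.unsqueeze(1)` gives the column shape used in the question. Everything stays on the input's device with no Python loop over groups.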

cheers