I am currently using torch.topk to determine the indices of the top-k values of a 2D tensor scores, which is of size [Batch, N]. I can get the top-k values (6000 of them) from scores with torch.gather, or simply take them directly from torch.topk.

import torch
top_scores, idx = torch.topk(scores, 6000, dim=1, sorted=True)  # topk returns (values, indices)
scores = torch.gather(scores, dim=1, index=idx)  # Output is of size [B, 6000], same as top_scores
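A small, self-contained check of the 2D case may help. Note that torch.topk returns a (values, indices) named tuple, so the indices have to be unpacked before they are passed to torch.gather. This sketch assumes random data and uses K = 3 instead of 6000 so the shapes are easy to read.

```python
import torch

torch.manual_seed(0)
B, N, K = 2, 10, 3  # illustrative sizes; the post uses K = 6000

scores = torch.rand(B, N)

# torch.topk returns a (values, indices) named tuple
top_values, idx = torch.topk(scores, K, dim=1, sorted=True)  # both [B, K]

# Gathering with the indices along dim 1 reproduces the values topk returned
gathered = torch.gather(scores, dim=1, index=idx)  # [B, K]
print(gathered.shape)  # torch.Size([2, 3])
```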

My issue comes when I try to use the same indices on a 3D tensor bbox, which is of size [Batch, N, 4]. How could I use the same indices to get the corresponding boxes (a tensor of size [Batch, 6000, 4]) without having to resort to for loops?
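A minimal sketch of two loop-free ways to do this, assuming idx of shape [B, K] comes from torch.topk along dim=1 as above. torch.gather requires the index to have the same number of dimensions as the input, so one option is to add a trailing dimension to idx and expand it across the 4 box coordinates. The other option is advanced indexing with a broadcast batch index. Random data and K = 3 are used for illustration only.

```python
import torch

torch.manual_seed(0)
B, N, K = 2, 10, 3  # illustrative sizes; the post uses K = 6000

scores = torch.rand(B, N)
bbox = torch.rand(B, N, 4)

_, idx = torch.topk(scores, K, dim=1, sorted=True)  # idx: [B, K]

# Option 1: torch.gather with an index expanded to [B, K, 4]
idx3 = idx.unsqueeze(-1).expand(-1, -1, bbox.size(-1))  # [B, K, 4]
top_bbox = torch.gather(bbox, dim=1, index=idx3)        # [B, K, 4]

# Option 2: advanced indexing; batch [B, 1] broadcasts against idx [B, K]
batch = torch.arange(B).unsqueeze(1)  # [B, 1]
top_bbox_adv = bbox[batch, idx]       # [B, K, 4]

print(top_bbox.shape)  # torch.Size([2, 3, 4])
```

Both options produce the same tensor. The expanded index in option 1 is a view, so no extra copy of idx is materialized.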

I understand that, but wouldn't that require a reshape or modification of scores to obtain indices that work for both the 2D and the 3D tensors, rather than a single solution that works for both?
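One way to use the unmodified idx on both tensors is advanced indexing with a broadcast batch index: x[batch, idx] selects along dim 1 and keeps any trailing dimensions. The helper name select_topk_rows below is hypothetical. This is a sketch assuming idx has shape [B, K] and indexes dim 1 of every tensor it is applied to.

```python
import torch


def select_topk_rows(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick entries of x along dim 1 using idx of shape [B, K].

    Works for x of shape [B, N] or [B, N, *trailing]; trailing dims are kept.
    """
    batch = torch.arange(x.size(0), device=x.device).unsqueeze(1)  # [B, 1]
    return x[batch, idx]  # [B, K] or [B, K, *trailing]


torch.manual_seed(0)
B, N, K = 2, 10, 3  # illustrative sizes
scores = torch.rand(B, N)
bbox = torch.rand(B, N, 4)

top_scores, idx = torch.topk(scores, K, dim=1, sorted=True)

# The same idx is used as-is for the 2D and the 3D tensor
picked_scores = select_topk_rows(scores, idx)  # [B, K]
picked_bbox = select_topk_rows(bbox, idx)      # [B, K, 4]
print(picked_scores.shape, picked_bbox.shape)
```

Because the batch index broadcasts against idx, neither scores nor idx needs to be reshaped; the trailing dimension of bbox simply passes through.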