PyTorch - how to select from bounding boxes (3D tensor) based on desired indices (2D tensor)?

Suppose I have a tensor batch_bboxes of shape (BATCH_SIZE, 10000, 4), where each bbox is represented by 4 float values.
I also have a tensor batch_best_idxs of shape (BATCH_SIZE, 100) that holds, for each batch element, the indices of the 100 bboxes with the highest confidence scores. These are the bboxes I want to select.
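Stated element by element, the selection I am after is the following, where batch_selected_bboxes is the (BATCH_SIZE, 100, 4) result built in the code below:

```latex
\texttt{batch\_selected\_bboxes}[b, k, :] = \texttt{batch\_bboxes}\bigl[b,\ \texttt{batch\_best\_idxs}[b, k],\ :\bigr],
\qquad 0 \le b < \texttt{BATCH\_SIZE},\quad 0 \le k < 100
```

That is, each row of batch_best_idxs indexes only into the bboxes of its own batch element, never across batch elements.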

The following program, which uses a loop, works:

import torch

BATCH_SIZE = 8  # example value; any positive batch size works

# 10000 candidate bboxes per batch element, 4 float values each
batch_bboxes = torch.randn(size=(BATCH_SIZE, 10000, 4), dtype=torch.float32)

print('\n' + 'batch_bboxes.shape = ' + str(batch_bboxes.shape) + '\n')

# indices of the 100 bboxes to keep for each batch element
batch_best_idxs = torch.randint(low=0, high=10000, size=(BATCH_SIZE, 100), dtype=torch.int64)

print('\n' + 'batch_best_idxs.shape = ' + str(batch_best_idxs.shape) + '\n')

# output buffer; the fill value is only a placeholder, every row is overwritten below
batch_selected_bboxes = torch.full(size=(BATCH_SIZE, 100, 4), fill_value=-9999.9999, dtype=torch.float32)
for batch_idx in range(BATCH_SIZE):
    bboxes = batch_bboxes[batch_idx]        # (10000, 4)
    best_idxs = batch_best_idxs[batch_idx]  # (100,)
    selected_bboxes = bboxes[best_idxs]     # (100, 4)
    batch_selected_bboxes[batch_idx] = selected_bboxes
# end for

print('\n' + 'batch_selected_bboxes.shape = ' + str(batch_selected_bboxes.shape) + '\n')

If I were willing to sacrifice readability, I could condense the for loop to:

for batch_idx in range(BATCH_SIZE):
    batch_selected_bboxes[batch_idx] = batch_bboxes[batch_idx][batch_best_idxs[batch_idx]]
# end for

But this is still slow because it iterates over the batch in Python. Is there a way to do this without a loop, using some combination of torch.index_select, torch.gather, or fancy (advanced) indexing? It seems like it should be possible, but so far I haven’t been able to work out how. Any suggestions?
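For reference, here is a sketch of what loop-free versions using each of the three options named above might look like. It assumes the shapes described in the question; the concrete sizes and the names batch_arange, gather_idx and flat_idx are illustrative only, not part of any API. Each result is compared against the loop version.

```python
import torch

# Illustrative sizes; BATCH_SIZE is not specified in the question.
BATCH_SIZE, NUM_BOXES, TOP_K = 8, 10000, 100

batch_bboxes = torch.randn(BATCH_SIZE, NUM_BOXES, 4)
batch_best_idxs = torch.randint(0, NUM_BOXES, (BATCH_SIZE, TOP_K))

# Loop version from the question, kept as the reference result.
expected = torch.stack([batch_bboxes[b][batch_best_idxs[b]] for b in range(BATCH_SIZE)])

# 1) Advanced ("fancy") indexing: a (B, 1) column of batch indices broadcasts
#    against the (B, K) box indices, so element [b, k] picks
#    batch_bboxes[b, batch_best_idxs[b, k]]; the trailing dim of 4 is kept.
batch_arange = torch.arange(BATCH_SIZE, device=batch_bboxes.device).unsqueeze(1)
selected_fancy = batch_bboxes[batch_arange, batch_best_idxs]  # (B, K, 4)

# 2) torch.gather: the index must have as many dims as the input, so repeat
#    each box index across the 4 coordinates, (B, K) -> (B, K, 4).
gather_idx = batch_best_idxs.unsqueeze(-1).expand(-1, -1, 4)
selected_gather = torch.gather(batch_bboxes, 1, gather_idx)  # (B, K, 4)

# 3) torch.index_select: it only accepts a 1-D index, so flatten the batch and
#    box dims and shift row b's indices by b * NUM_BOXES.
flat_idx = (batch_best_idxs + batch_arange * NUM_BOXES).reshape(-1)  # (B * K,)
selected_index_select = (
    batch_bboxes.reshape(-1, 4).index_select(0, flat_idx).reshape(BATCH_SIZE, TOP_K, 4)
)

for result in (selected_fancy, selected_gather, selected_index_select):
    print(result.shape, torch.equal(result, expected))
```

All three produce a (BATCH_SIZE, 100, 4) tensor. The fancy-indexing form is the shortest; the gather form needs the index expanded to the full output shape; the index_select form needs the manual offset arithmetic. Creating batch_arange with device=batch_bboxes.device keeps the indices on the same device as the boxes if they live on a GPU.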