Suppose I have a tensor batch_bboxes of shape (BATCH_SIZE, 10000, 4), where each bbox is represented by 4 float values. I also have a tensor batch_best_idxs of shape (BATCH_SIZE, 100), which, for each batch element, holds the indices of the 100 bboxes with the highest confidence scores that I'm trying to select.
The following program using a loop would work:
import torch
BATCH_SIZE = 8
batch_bboxes = torch.randn(size=(BATCH_SIZE, 10000, 4), dtype=torch.float32)
print('\n' + 'batch_bboxes.shape = ' + str(batch_bboxes.shape) + '\n')
batch_best_idxs = torch.randint(low=0, high=10000, size=(BATCH_SIZE, 100)).type(torch.int64)
print('\n' + 'batch_best_idxs.shape = ' + str(batch_best_idxs.shape) + '\n')
batch_selected_bboxes = torch.full(size=(BATCH_SIZE, 100, 4), fill_value=-9999.9999, dtype=torch.float32)
for batch_idx in range(BATCH_SIZE):
    bboxes = batch_bboxes[batch_idx]
    best_idxs = batch_best_idxs[batch_idx]
    selected_bboxes = bboxes[best_idxs]
    batch_selected_bboxes[batch_idx] = selected_bboxes
# end for
print('\n' + 'batch_selected_bboxes.shape = ' + str(batch_selected_bboxes.shape) + '\n')
If I were willing to sacrifice readability, I could change the for loop part to:
for batch_idx in range(BATCH_SIZE):
    batch_selected_bboxes[batch_idx] = batch_bboxes[batch_idx][batch_best_idxs[batch_idx]]
# end for
But this is still slow since it's using a loop. Is there a way to do this without a loop, instead using some combination of torch.index_select, torch.gather, or fancy indexing? It seems like it should be possible, but so far I haven't been able to work out how. Any suggestions?
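In case it helps show what I'm after, this is the kind of gather-based call I've been experimenting with. I'm not confident that expanding the index tensor with unsqueeze/expand is the right (or most idiomatic) way to line the shapes up, so please treat this as a sketch of the direction rather than something I know is correct:

# sketch: expand batch_best_idxs from (BATCH_SIZE, 100) to (BATCH_SIZE, 100, 4)
# so it matches the last dimension of batch_bboxes, then gather along dim=1
# (the 10000-bbox dimension)
idxs_expanded = batch_best_idxs.unsqueeze(-1).expand(-1, -1, 4)
gathered_bboxes = torch.gather(batch_bboxes, dim=1, index=idxs_expanded)
print(gathered_bboxes.shape)  # hoping for torch.Size([BATCH_SIZE, 100, 4])

I've also wondered whether plain fancy indexing with something like batch_bboxes[torch.arange(BATCH_SIZE).unsqueeze(1), batch_best_idxs] would do the same thing, but I'm not sure which approach is preferred or whether either is actually correct here.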