How can I accelerate this batch-wise `index_select` function?


import torch
def batch_wise_index_select(input, dim, index):
    """Select entries along ``dim`` using a separate index list per sample.

    Equivalent to running ``torch.index_select`` on every ``input[i]`` with
    ``index[i]`` and stacking the results, but done with a single
    ``torch.gather`` call instead of a Python loop over the batch.

    Args:
        input: Tensor whose dimension 0 is the batch dimension.
        dim: Dimension of ``input`` (counted on the full, batched tensor) to
            select along. Negative values count from the end. The batch
            dimension itself (0) is not allowed.
        index: Integer tensor whose dimension 0 matches ``input.size(0)``.
            All remaining dimensions are flattened into one list of K
            indices per sample.

    Returns:
        A tensor shaped like ``input`` except that dimension ``dim`` has
        size K, with the same dtype as ``input``.

    Raises:
        ValueError: If ``input`` and ``index`` disagree on the batch size.
        IndexError: If ``dim`` is the batch dimension or out of range.
    """
    ndim = input.dim()
    batch = input.size(0)
    if index.size(0) != batch:
        raise ValueError(
            f"batch size mismatch: input has {batch}, index has {index.size(0)}"
        )

    # Normalize negative dims against the full tensor so that dim=-1 really
    # means the last axis.
    axis = dim + ndim if dim < 0 else dim
    if not 1 <= axis < ndim:
        raise IndexError(
            f"dim must address a non-batch dimension of a {ndim}-D input, "
            f"got {dim}"
        )

    # index_select accepts int32 indices, gather wants int64.
    if index.dtype == torch.int32:
        index = index.long()

    # Number of indices per sample. Computed explicitly (instead of
    # reshape(batch, -1)) so zero-sized index tensors keep working.
    picks = 1
    for extent in index.shape[1:]:
        picks *= extent

    # Lay the indices out as (batch, 1, ..., picks, ..., 1) and broadcast
    # them over every other dimension; expand() creates no copy.
    idx_shape = [1] * ndim
    idx_shape[0] = batch
    idx_shape[axis] = picks
    out_shape = list(input.shape)
    out_shape[axis] = picks
    gather_index = index.reshape(idx_shape).expand(out_shape)

    return torch.gather(input, axis, gather_index)

# Demo: a random int32 tensor of shape (2, 2, 5) and, for each of the two
# samples, three random positions (0, 1 or 2) to pick along the last axis.
a = torch.rand(2, 2, 5).mul(10).int()
b = torch.rand(2, 3).mul(3).long()
out = batch_wise_index_select(a, dim=2, index=b)

# Show the data, the indices and the selection, one tensor after another.
print(a, b, out, sep="\n")

Output

tensor([[[6, 4, 7, 4, 0],
         [2, 3, 1, 4, 5]],

        [[6, 2, 6, 6, 3],
         [7, 4, 7, 2, 9]]], dtype=torch.int32)
tensor([[1, 1, 2],
        [1, 1, 0]])
tensor([[[4, 4, 7],
         [3, 3, 1]],

        [[2, 2, 6],
         [4, 4, 7]]], dtype=torch.int32)

How can `batch_wise_index_select` be accelerated, ideally without the Python loop over the batch dimension?