import torch
def batch_wise_index_select(input, dim, index):
assert input.size(0) == index.size(0)
out = []
for i in range(input.size(0)):
out.append(torch.index_select(input[i], dim=dim-1, index=index[i].view(-1)))
return torch.stack(out, dim=0)
a = (torch.rand(2,2,5)*10).int()
b = (torch.rand(2,3)*3).long()
out = batch_wise_index_select(a, dim=2, index=b)
print(a)
print(b)
print(out)
Output
tensor([[[6, 4, 7, 4, 0],
[2, 3, 1, 4, 5]],
[[6, 2, 6, 6, 3],
[7, 4, 7, 2, 9]]], dtype=torch.int32)
tensor([[1, 1, 2],
[1, 1, 0]])
tensor([[[4, 4, 7],
[3, 3, 1]],
[[2, 2, 6],
[4, 4, 7]]], dtype=torch.int32)
Then how to accelerate batch_wise_index_select ?