You can do this with the following function, which expands the index and then calls `gather`:
def batched_index_select(input, dim, index):
    """Select entries along ``dim`` independently for each batch element.

    For every batch element ``b`` the result holds ``input[b]`` indexed
    along ``dim`` by ``index[b]``. All other dimensions are kept whole.

    Args:
        input: Tensor of shape ``(B, d1, ..., dn)``; dimension 0 is the batch.
        dim: Dimension to select along. Negative values count from the end;
            it must not resolve to the batch dimension 0.
        index: Integer tensor of shape ``(B, K)`` holding, per batch element,
            the positions to pick along ``dim`` (``gather`` expects int64).

    Returns:
        Tensor shaped like ``input`` except that dimension ``dim`` has size K.

    Raises:
        IndexError: if ``dim`` is out of range for ``input``.
        ValueError: if ``dim`` refers to the batch dimension 0.
    """
    ndim = input.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D input")
    if dim < 0:
        # The unsqueeze loop and the expand list below need a non-negative dim.
        dim += ndim
    if dim == 0:
        raise ValueError("dim must not be the batch dimension 0")

    # Give index a singleton axis for every input axis other than the batch
    # axis (0) and `dim`, so (B, K) becomes (B, 1, ..., K, ..., 1).
    for axis in range(1, ndim):
        if axis != dim:
            index = index.unsqueeze(axis)

    # Broadcast the singleton axes to input's sizes; -1 keeps B and K unchanged.
    sizes = list(input.shape)
    sizes[0] = -1
    sizes[dim] = -1
    index = index.expand(sizes)

    # Tensor.gather is the method form of torch.gather, so no module-level
    # `torch` name is required here.
    return input.gather(dim, index)
Taken from an external post (the original link was not preserved here).