This appears to have been solved here, although, as far as I can tell, the solution is not in-place.
The solution:
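    import torch

    def batched_index_select(input, dim, index):
        # index has shape (B, K): K positions along `dim` for each batch element.
        # dim is assumed to be >= 1; dimension 0 is the batch dimension.
        # Insert a size-1 axis at every dimension of input except 0 and dim.
        for ii in range(1, len(input.shape)):
            if ii != dim:
                index = index.unsqueeze(ii)
        # Expand those size-1 axes to match input; -1 keeps the B and K sizes.
        expanse = list(input.shape)
        expanse[0] = -1
        expanse[dim] = -1
        index = index.expand(expanse)
        # out[b, ..., k, ...] = input[b, ..., index[b, k], ...]
        return torch.gather(input, dim, index)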
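A quick way to check what the function returns is to compare it against a plain per-batch loop over `Tensor.index_select`, and against advanced indexing with a batch `arange`. The sketch below repeats the function so it runs on its own; the tensor shapes and index values are made up for illustration.

```python
import torch

def batched_index_select(input, dim, index):
    # Same function as above: reshape index (B, K) so it can be expanded to input's shape
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)

torch.manual_seed(0)
x = torch.randn(2, 5, 3)             # (batch, items, features)
idx = torch.tensor([[0, 2, 2, 4],    # items to pick for batch element 0
                    [1, 1, 3, 0]])   # items to pick for batch element 1

out = batched_index_select(x, 1, idx)
print(out.shape)  # torch.Size([2, 4, 3])

# Reference 1: loop over the batch, calling index_select on each slice
ref = torch.stack([x[b].index_select(0, idx[b]) for b in range(x.size(0))])
# Reference 2: advanced indexing, broadcasting a (B, 1) batch index against idx
alt = x[torch.arange(x.size(0)).unsqueeze(1), idx]

print(torch.equal(out, ref), torch.equal(out, alt))  # True True
```

Note that `gather` (like the loop and the advanced-indexing version) allocates a new tensor of shape `(B, K, ...)` for the result.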
If anyone can figure out an in-place solution, it would be deeply appreciated. I will also keep trying and report back here if I make any progress.
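One reading of "in-place" (an assumption, since the selection itself necessarily has a different shape from `input`) is writing values into the selected positions of `input` rather than reading them out. Under that reading, the same expanded index can be passed to `Tensor.scatter_`, which modifies the tensor it is called on. The helper names `expand_batched_index` and `batched_index_copy_` below are hypothetical, not PyTorch functions; this is a sketch, not a tested library routine.

```python
import torch

def expand_batched_index(input, dim, index):
    # Hypothetical helper: the same reshaping as batched_index_select,
    # turning index (B, K) into an index with input's shape except K at `dim`
    for ii in range(1, input.dim()):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    return index.expand(expanse)

def batched_index_copy_(input, dim, index, src):
    # Hypothetical helper: writes src into input at the positions chosen by
    # index, modifying input in place. src must have the shape that
    # batched_index_select(input, dim, index) would return.
    full_index = expand_batched_index(input, dim, index)
    return input.scatter_(dim, full_index, src)

x = torch.zeros(2, 5, 3)
idx = torch.tensor([[0, 2, 4],
                    [1, 3, 0]])
vals = torch.ones(2, 3, 3)

result = batched_index_copy_(x, 1, idx, vals)
print(x[0, :, 0])  # tensor([1., 0., 1., 0., 1.])
print(x[1, :, 0])  # tensor([1., 1., 0., 1., 0.])
```

If the goal is instead only to avoid allocating a fresh result on every call, `torch.gather` also accepts an `out=` argument, so a preallocated buffer of the output shape can be reused. If `index` repeats a position within one batch row, which `src` value ends up written there with `scatter_` is not something to rely on.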