Indexing a tensor with unique indices along each element of a batch

This appears to have been solved here. However, as far as I can tell, that solution is not in-place.

The solution:

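import torch

def batched_index_select(input, dim, index):
    # index has shape (batch, k); add singleton dims so it lines up with input
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    # expand to input's shape, leaving the batch dim and the selected dim as they are
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    # gather returns a new tensor, so this is not in-place
    return torch.gather(input, dim, index)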
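
To make the shapes concrete, here is a small usage sketch. The tensor sizes and index values are made up for illustration, and the function is repeated so the snippet runs on its own. It also compares the result against plain advanced indexing, which should give the same selection when dim=1.

```python
import torch

def batched_index_select(input, dim, index):
    # Same function as in the post, repeated so this snippet is self-contained.
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    return torch.gather(input, dim, index)

# Hypothetical example data: 2 batch elements, 4 rows, 3 features each.
x = torch.arange(24).reshape(2, 4, 3)
# Each batch element picks its own two (unique) rows along dim=1.
idx = torch.tensor([[0, 3], [2, 1]])

out = batched_index_select(x, 1, idx)  # shape (2, 2, 3)

# The same selection written with advanced indexing, for comparison.
batch = torch.arange(x.shape[0]).unsqueeze(1)  # shape (2, 1), broadcasts against idx
expected = x[batch, idx]
```

Both out and expected are new tensors, so writing to either one leaves x unchanged.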

If anyone can figure out an in-place solution, it would be deeply appreciated. I will also keep trying and will report back here if I make any progress.
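
One direction that might work, assuming "in-place" here means writing values back into the selected positions of input rather than reading a copy out of it: torch.Tensor.scatter_ takes the same kind of expanded index that torch.gather does and modifies the tensor it is called on. The helper name batched_index_put_ and its src argument below are my own invention for this sketch, not something from PyTorch or from the linked solution, and I have only tried it on small cases.

```python
import torch

def batched_index_put_(input, dim, index, src):
    # Hypothetical helper: name and signature are assumptions, not a PyTorch API.
    # Build the expanded index the same way as in the gather-based version.
    for ii in range(1, input.dim()):
        if ii != dim:
            index = index.unsqueeze(ii)
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    # scatter_ writes src into input at the indexed positions, in place,
    # and returns input itself.
    return input.scatter_(dim, index, src)

# Hypothetical example data: 2 batch elements, 4 rows, 3 features each.
x = torch.zeros(2, 4, 3)
idx = torch.tensor([[0, 3], [2, 1]])  # unique rows per batch element
src = torch.ones(2, 2, 3)             # one value per selected position

result = batched_index_put_(x, 1, idx, src)
```

Because the indices are unique within each batch element, no two writes land on the same position, so the order in which scatter_ applies them should not matter.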