Indexing a tensor with unique indices for each element of a batch

Let's say we have a tensor of shape [N, C, D], where N = batch size, C = channels, and D = H*W. For this example we can use a tensor of shape [10, 64, 9].

Let's say I want to index each batch element along the D dimension with 4 unique indices per element (for example, [2, 5, 7, 8] for batch element one, [1, 2, 4, 6] for batch element two, … , [0, 7, 2, 8] for batch element ten) to yield an output tensor of shape [10, 64, 4].
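For concreteness, a minimal sketch of the setup matching these shapes (the names x and idx are just placeholders of my choosing):

import torch

x = torch.randn(10, 64, 9)  # [N, C, D]
# Four unique indices into D per batch element, shape [10, 4]
idx = torch.stack([torch.randperm(9)[:4] for _ in range(10)])
# Goal: out[n, c, k] = x[n, c, idx[n, k]], i.e. out.shape == [10, 64, 4]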

Is there a way to accomplish this without looping? I believe I would have to make use of torch.index_select somehow, but I am not entirely certain. It would also be great if I could accomplish this in-place. Here is a code snippet of me solving this for just one batch element (both because I could really use the help and because I don't want people to think this is a low-effort post):

# Where we get our indices
>>> temp.shape
torch.Size([32768, 9])
# Our tensor of interest
>>> temp2.shape
torch.Size([32768, 64, 9])
# Generating our indices
>>> torch.topk(temp, 4).indices.shape
torch.Size([32768, 4])
# Generating our output for just one batch element
>>> torch.index_select(temp2[1], 1, torch.topk(temp, 4).indices[1]).shape
torch.Size([64, 4])
# How can we do this for all batch elements without looping (and ideally without out-of-place operations)?
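(As it turns out, for these exact shapes the batched version reduces to a single torch.gather call, which is what the general solution quoted below does for an arbitrary dim; a sketch reusing temp and temp2 from above:)

idx = torch.topk(temp, 4).indices                     # [32768, 4]
idx = idx.unsqueeze(1).expand(-1, temp2.size(1), -1)  # [32768, 64, 4]
out = torch.gather(temp2, 2, idx)                     # out.shape == [32768, 64, 4]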


This appears to have been solved here. However, as far as I can tell, the linked solution is not in-place.

The solution:

import torch

def batched_index_select(input, dim, index):
    # index starts as [batch, k]; add singleton dims everywhere except `dim`
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    # Expand index to match input on every dim except the batch dim and `dim`
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    # Gather the k selected entries along `dim` for every batch element
    return torch.gather(input, dim, index)
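
Applied to the tensors above, usage would look something like this (reusing temp and temp2 from my snippet):

out = batched_index_select(temp2, 2, torch.topk(temp, 4).indices)
# out.shape == torch.Size([32768, 64, 4])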

If anyone can figure out an in-place solution it would be deeply appreciated. I will also keep trying and will report back here if I make any progress.
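
Regarding in-place: since the output shape ([N, C, 4]) differs from the input shape ([N, C, D]), I don't think a truly in-place result is possible. The closest option I'm aware of is writing into a preallocated buffer via the out= keyword of torch.gather; a sketch, reusing the expanded idx from the earlier snippet:

out = torch.empty(32768, 64, 4)       # allocate the result buffer once
torch.gather(temp2, 2, idx, out=out)  # repeated calls reuse `out` instead of allocating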