Let's say we have a tensor of shape [N, C, D], where N = batch size, C = channels, and D = H*W. For this example we can use a tensor of shape [10, 64, 9].
Let's say I want to index each batch element along the D dimension with 4 unique indices per element (for example, [2, 5, 7, 8] for batch element one, [1, 2, 4, 6] for batch element two, …, [0, 7, 2, 8] for batch element ten) to yield an output tensor of shape [10, 64, 4].
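To make the goal concrete, here is a minimal loop-based sketch of the operation I want to vectorize (the names x and idx are mine, just for illustration):

import torch

N, C, D, K = 10, 64, 9, 4
x = torch.randn(N, C, D)
# One row of K unique positions along D per batch element,
# e.g. idx[0] could be [2, 5, 7, 8].
idx = torch.stack([torch.randperm(D)[:K] for _ in range(N)])  # [N, K]
# The loop-based version of the indexing I want:
out = torch.stack([x[n, :, idx[n]] for n in range(N)])        # [N, C, K]
print(out.shape)  # torch.Size([10, 64, 4])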
Is there a way to accomplish this without looping? I believe I would have to make use of torch.index_select somehow, but am not entirely certain. It would also be great if I could accomplish this in-place. Here is a code snippet showing my solution for just one batch element (included because I could really use the help and don't want people to think this is a low-effort post):
# Where we get our indices
>>> temp.shape
torch.Size([32768, 9])
# Our tensor of interest
>>> temp2.shape
torch.Size([32768, 64, 9])
# Generating our indices
>>> torch.topk(temp, 4).indices.shape
torch.Size([32768, 4])
# Generating the output for a single batch element (index 1)
>>> torch.index_select(temp2[1], 1, torch.topk(temp, 4).indices[1]).shape
torch.Size([64, 4])
# How can we do this for all batch elements at once, without looping (and ideally in-place)?
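For what it's worth, my current guess (untested, so it may well be wrong) is that torch.gather along the last dimension could do this, by expanding the per-batch indices across the channel dimension:

# Untested sketch: expand the [N, 4] index tensor across the channel
# dimension so it matches the output shape [N, C, 4], then gather along D.
idx = torch.topk(temp, 4).indices                              # [N, 4]
idx_expanded = idx.unsqueeze(1).expand(-1, temp2.size(1), -1)  # [N, C, 4]
out = temp2.gather(2, idx_expanded)                            # [N, C, 4], i.e. [32768, 64, 4]

If that's on the right track, I'd still like to confirm it, and to know whether passing a preallocated tensor via gather's out= argument would cover the in-place part of my question.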