Let's say we have a tensor of shape [N, C, D], where N = batch, C = channels, and D = H*W. For this example, we can use a tensor of shape [10, 64, 9].
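As a quick sketch of the layout described above, assume the [N, C, D] tensor comes from flattening a hypothetical [N, C, H, W] feature map with H = W = 3. `flatten(start_dim=2)` then gives the [10, 64, 9] shape, and position d along D maps back to (h, w) in row-major order:

```python
import torch

# Hypothetical 4-D feature map: batch N=10, channels C=64, spatial H=W=3.
x4d = torch.randn(10, 64, 3, 3)

# Collapse the two spatial dims into a single D = H*W axis -> [N, C, D].
x = x4d.flatten(start_dim=2)
print(x.shape)  # torch.Size([10, 64, 9])

# Row-major flattening: d = h * W + w, so d = 5 corresponds to (h, w) = (1, 2).
h, w = divmod(5, 3)
assert torch.equal(x[:, :, 5], x4d[:, :, h, w])
```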
Let's say I want to index each batch element along the D dimension with 4 unique indices per element, applying the same indices across all 64 channels of that element (for example, [2, 5, 7, 8] for batch element one, [1, 2, 4, 6] for batch element two, … , [0, 7, 2, 8] for batch element ten), to yield an output tensor of shape [10, 64, 4].
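One possible loop-free way to express this selection is `torch.gather` along dim 2, after broadcasting each element's indices across the channel dimension. This is a minimal sketch on a toy batch of three elements that reuses the three index rows quoted above (the remaining rows are elided in the post and are not reconstructed here); the random input values are illustrative only.

```python
import torch

N, C, D, K = 3, 64, 9, 4
x = torch.randn(N, C, D)

# Toy batch of three elements reusing the three index rows quoted above;
# each row holds K=4 unique positions along D for one batch element.
indices = torch.tensor([[2, 5, 7, 8],
                        [1, 2, 4, 6],
                        [0, 7, 2, 8]])  # [N, K], dtype int64

# gather needs an index with the same number of dims as x, so broadcast the
# per-element indices across channels: [N, K] -> [N, 1, K] -> [N, C, K].
idx = indices.unsqueeze(1).expand(-1, C, -1)

# out[n, c, k] = x[n, c, indices[n, k]]
out = torch.gather(x, dim=2, index=idx)
print(out.shape)  # torch.Size([3, 64, 4])
```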
Is there a way to accomplish this without looping? I believe I would have to use torch.index_select somehow, but I'm not entirely certain. It would also be great if I could do this in-place. Here is a code snippet of me solving this for just one batch element (I could really use the help, and I don't want people to think this is a low-effort post):
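For comparison, here is a sketch of the per-element `index_select` loop next to a loop-free advanced-indexing version that should give the same [N, C, K] result. The random indices are an assumption for illustration. On the in-place part: the output is [N, C, K] while the input is [N, C, D], so the result cannot literally overwrite the input; the sketch assumes that writing into a preallocated buffer through `torch.gather`'s `out=` argument is an acceptable substitute.

```python
import torch

N, C, D, K = 10, 64, 9, 4
x = torch.randn(N, C, D)
# K unique positions per batch element (random permutations, illustration only).
indices = torch.rand(N, D).argsort(dim=1)[:, :K]  # [N, K]

# Baseline: one index_select call per batch element (the loop to avoid).
looped = torch.stack([torch.index_select(x[n], 1, indices[n]) for n in range(N)])

# Loop-free advanced indexing: the three index tensors broadcast to [N, C, K],
# so vectorized[n, c, k] = x[n, c, indices[n, k]].
n_idx = torch.arange(N)[:, None, None]  # [N, 1, 1]
c_idx = torch.arange(C)[None, :, None]  # [1, C, 1]
vectorized = x[n_idx, c_idx, indices[:, None, :]]  # [N, C, K]

# Reusable output buffer: gather writes its result into buf instead of
# allocating a new tensor (x itself is left untouched).
buf = torch.empty(N, C, K)
torch.gather(x, 2, indices.unsqueeze(1).expand(-1, C, -1), out=buf)
```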
# Where we get our indices
>>> temp.shape
torch.Size([32768, 9])
# Our tensor of interest
>>> temp2.shape
torch.Size([32768, 64, 9])
# Generating our indices
>>> torch.topk(temp, 4).indices.shape
torch.Size([32768, 4])
# Generating our outputs for just one batch element (element 0)
>>> torch.index_select(temp2[0], 1, torch.topk(temp[0], 4).indices).shape
torch.Size([64, 4])
# How can we do this for all of the batch elements without looping/out-of-place operations?
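Applied to the snippet above, a sketch of the batched version could replace the per-element `index_select` with a single `gather` over the `topk` indices. It assumes `temp` is [B, 9] and `temp2` is [B, 64, 9] as printed; B = 256 is a smaller stand-in for 32768 so the example runs quickly, and the random data is illustrative.

```python
import torch

B, C, D, K = 256, 64, 9, 4      # B is a small stand-in for the 32768 rows above
temp = torch.randn(B, D)        # scores used to pick positions along D
temp2 = torch.randn(B, C, D)    # tensor of interest

top_idx = torch.topk(temp, K).indices          # [B, K], one row per batch element
idx = top_idx.unsqueeze(1).expand(-1, C, -1)   # [B, C, K], a view (no copy)
out = temp2.gather(2, idx)                     # out[b, c, k] = temp2[b, c, top_idx[b, k]]
```

Because `expand` returns a view, the [B, 64, 4] index tensor is not materialized in memory; only `out` is newly allocated.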