# Select sub-tensors via indices from torch.topk

I can't seem to find a clean way to index a tensor with a different set of indices for each batch element. This must be a common use case, so I bet there is a nice solution for it:

Assume the following:

import torch

# I have following quantities
batch_size = 6
L = 100
K = 10
d = 300

# I have a matrix of vectors
# batch X L x d
R = torch.randn(batch_size, L, d)

# I have a matrix of scores (each element of this is a score assigned to vector from matrix R via some transformation)
# batch_size x L
scores = torch.randn(batch_size, L)

# I get the top indices of each
# batch x K
topk_positions = torch.topk(scores, K, dim=1)[1]

# example of topk_positions contents
# tensor([[108,  16,  15, 107,  22, 106,  19,  11,  20,  99],
#         [118, 113, 167, 112, 108, 171, 111, 117, 138,  18],
#         [ 43,  54,  44,  35,  36,  40,  55,  51,  10,  47],
#         [ 89,  84,  78,  83,  82,  92,  67,  72,  86,  85],
#         [ 57,  46,  58,  50,  47,  38,  39,  43,  54,  13],
#         [ 46,  57,  47,  38,  39,  54,  58,  43,  13,  50]])

# now I would like to index these positions into matrix R, so I will get the vectors corresponding to the top elements
# the resulting matrix should be of shape batch_size x K x d

R_best = ?

# it should be equivalent to
R_best_ugly = torch.zeros((batch_size, K, d))
for i in range(batch_size):  # over batch
    R_best_ugly[i] = R[i][topk_positions[i]]

assert R_best.allclose(R_best_ugly)

Finally, solved it thanks to my colleague!

R_best = R.gather(1, topk_positions.unsqueeze(2).expand(batch_size, K, d))
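For anyone landing here: gather requires the index tensor to have the same number of dimensions as the source, which is why topk_positions has to be unsqueezed and expanded to batch_size x K x d. The same result can also be obtained with advanced indexing via a broadcasted batch index, which some may find more readable. A minimal self-contained sketch (variable names mirror the question):

```python
import torch

torch.manual_seed(0)
batch_size, L, K, d = 6, 100, 10, 300
R = torch.randn(batch_size, L, d)
scores = torch.randn(batch_size, L)
topk_positions = torch.topk(scores, K, dim=1)[1]  # batch_size x K

# gather along dim 1: the index must be expanded to batch_size x K x d
R_best = R.gather(1, topk_positions.unsqueeze(2).expand(batch_size, K, d))

# equivalent: advanced indexing with a batch index of shape batch_size x 1,
# which broadcasts against topk_positions (batch_size x K)
batch_idx = torch.arange(batch_size).unsqueeze(1)
R_best_ix = R[batch_idx, topk_positions]  # batch_size x K x d

assert R_best.shape == (batch_size, K, d)
assert torch.equal(R_best, R_best_ix)
```

Note that expand does not copy memory, so the gather version stays cheap even for large d.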