I cannot seem to find a clean way to index a tensor with a different list of indices for each element along the batch dimension. This is a common use case, so I bet there is a nice solution for it.
Assume the following:
import torch
# I have following quantities
batch_size = 6
L = 100
K = 10
d = 300
# I have a matrix of vectors
# batch_size x L x d
R = torch.randn(batch_size, L, d)
# I have a matrix of scores (each element of this is a score assigned to vector from matrix R via some transformation)
# batch_size x L
scores = torch.randn(batch_size, L)
# I get the top indices of each
# batch_size x K
topk_positions = torch.topk(scores, K, dim=1)[1]
# example of topk_positions contents (illustrative only; some values exceed L - 1 for the L = 100 above)
# tensor([[108, 16, 15, 107, 22, 106, 19, 11, 20, 99],
# [118, 113, 167, 112, 108, 171, 111, 117, 138, 18],
# [ 43, 54, 44, 35, 36, 40, 55, 51, 10, 47],
# [ 89, 84, 78, 83, 82, 92, 67, 72, 86, 85],
# [ 57, 46, 58, 50, 47, 38, 39, 43, 54, 13],
# [ 46, 57, 47, 38, 39, 54, 58, 43, 13, 50]])
# now I would like to index these positions into matrix R, so I will get the vectors corresponding to the top elements
# the resulting matrix should be of shape batch_size x K x d
R_best = ?
# it should be equivalent to
R_best_ugly = torch.zeros((batch_size, K, d))
for i, positions in enumerate(topk_positions):  # over batch
    R_best_ugly[i] = R[i][positions]
assert R_best.allclose(R_best_ugly)
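
For what it's worth, here is a minimal sketch of two ways this might be done without the loop, assuming the shapes given above. The first pairs each row of `topk_positions` with its batch index through broadcasting. The second uses `torch.gather` with the index expanded over the last dimension. The names `batch_idx`, `gather_idx` and `R_best_gather` are my own and purely illustrative.

```python
import torch

batch_size, L, K, d = 6, 100, 10, 300
R = torch.randn(batch_size, L, d)                  # batch_size x L x d
scores = torch.randn(batch_size, L)                # batch_size x L
topk_positions = torch.topk(scores, K, dim=1)[1]   # batch_size x K

# Option 1: advanced indexing. batch_idx (batch_size x 1) broadcasts
# against topk_positions (batch_size x K), so element [i, j] picks
# R[i, topk_positions[i, j]].
batch_idx = torch.arange(batch_size).unsqueeze(1)
R_best = R[batch_idx, topk_positions]              # batch_size x K x d

# Option 2: torch.gather along dim 1. The index must have the same
# number of dimensions as R, so it is expanded across d.
gather_idx = topk_positions.unsqueeze(-1).expand(-1, -1, d)
R_best_gather = torch.gather(R, 1, gather_idx)     # batch_size x K x d

# Reference: the explicit loop from the question.
R_best_ugly = torch.zeros((batch_size, K, d))
for i, positions in enumerate(topk_positions):
    R_best_ugly[i] = R[i][positions]

assert R_best.allclose(R_best_ugly)
assert R_best_gather.allclose(R_best_ugly)
```

Both variants only copy values, so they should match the loop exactly. Which one reads better is a matter of taste.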