I also have the same question.
To clarify, suppose I have a tensor A with A.shape = 2 x 3 x 2 (B x N x C):
A = torch.tensor(
    [[[ 0.4930,  0.1150],
      [-0.2355,  1.1917],
      [-1.2421, -0.4383]],
     [[ 0.3099,  3.4751],
      [ 0.7780,  1.0990],
      [-0.0795,  0.1633]]])
And an index tensor idx with idx.shape = 2 x 3 x 1 (B x N x K):
idx = torch.tensor(
    [[[0],
      [1],
      [2]],
     [[1],
      [2],
      [1]]])
I would like to get a 4D tensor res with res.shape = 2 x 3 x 1 x 2 (B x N x K x C), where res[b, n, k] = A[b, idx[b, n, k]]:
res = tensor(
    [[[[ 0.4930,  0.1150]],
      [[-0.2355,  1.1917]],
      [[-1.2421, -0.4383]]],
     [[[ 0.7780,  1.0990]],
      [[-0.0795,  0.1633]],
      [[ 0.7780,  1.0990]]]])
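To pin down exactly what I mean, here is a slow reference sketch with explicit loops. Each index in idx picks a row of A along the N dimension within the same batch. The helper name gather_rows_reference is made up for illustration only.

```python
import torch

def gather_rows_reference(A, idx):
    # Hypothetical helper: explicit-loop reference, not meant to be fast.
    # A: (B, N, C), idx: (B, N, K) integer tensor -> result: (B, N, K, C)
    B, N, C = A.shape
    K = idx.shape[-1]
    res = A.new_empty(B, N, K, C)
    for b in range(B):
        for n in range(N):
            for k in range(K):
                # Pick row idx[b, n, k] of A within batch b.
                res[b, n, k] = A[b, idx[b, n, k]]
    return res

A = torch.tensor(
    [[[ 0.4930,  0.1150],
      [-0.2355,  1.1917],
      [-1.2421, -0.4383]],
     [[ 0.3099,  3.4751],
      [ 0.7780,  1.0990],
      [-0.0795,  0.1633]]])
idx = torch.tensor([[[0], [1], [2]],
                    [[1], [2], [1]]])

res = gather_rows_reference(A, idx)
print(res.shape)  # torch.Size([2, 3, 1, 2])
```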
Currently, I implement it with a Python loop over the batch dimension:
B, N, K, C = *idx.shape, A.shape[-1]
res = torch.cat([A[i, idx[i]] for i in range(B)], dim=0).view(B, N, K, C)
Is there a more efficient way to do this (e.g. one that gets rid of the for loop)?
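Two loop-free candidates I can think of are sketched below: advanced indexing with a broadcast batch index, and torch.gather on an expanded index. The function names are made up for illustration, and I have not benchmarked either one. The sketch only checks that they match the loop version on random data.

```python
import torch

def gather_rows(A, idx):
    # Hypothetical helper. A: (B, N, C), idx: (B, N, K) -> (B, N, K, C)
    B = A.shape[0]
    # Batch index of shape (B, 1, 1) broadcasts against idx of shape (B, N, K).
    batch = torch.arange(B, device=A.device).view(B, 1, 1)
    return A[batch, idx]

def gather_rows_via_gather(A, idx):
    # Hypothetical helper: same result using torch.gather along dim 1.
    B, N, C = A.shape
    K = idx.shape[-1]
    # Repeat each index across the channel dimension: (B, N*K, C).
    flat = idx.reshape(B, N * K, 1).expand(-1, -1, C)
    return torch.gather(A, 1, flat).reshape(B, N, K, C)

# Compare against the loop-based version on random data.
torch.manual_seed(0)
B, N, K, C = 4, 5, 2, 3
A = torch.randn(B, N, C)
idx = torch.randint(0, N, (B, N, K))

loop = torch.cat([A[i, idx[i]] for i in range(B)], dim=0).view(B, N, K, C)
out_index = gather_rows(A, idx)
out_gather = gather_rows_via_gather(A, idx)

print(out_index.shape)  # torch.Size([4, 5, 2, 3])
print(torch.equal(out_index, loop), torch.equal(out_gather, loop))
```

Both avoid the Python-level loop, but I am not sure which one is preferable in terms of speed or memory.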
Thanks in advance!