I have a tensor A of shape (256, 512, 1024) and an index tensor Index of shape (256, 960, 2). For example, an entry [1, 2] in Index means I want to retrieve the slices at indices 1 and 2 of A along dimension 1, for the corresponding batch element. The result should be a tensor Res of shape (256, 960, 2, 1024).
So far I can only implement this by iterating over all elements. If there is a better way to do this, please let me know.
Here is a simplified version of my setup:
import torch

if __name__ == '__main__':
    src = torch.FloatTensor([
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    ])
    idx = torch.LongTensor([
        [[0, 1], [1, 2], [0, 2], [1, 0], [2, 1]]
    ])
    res = torch.zeros((1, 5, 2, 4))
    # I want (1,5,2,4)
    # [ [[1,2,3,4],[5,6,7,8]],[[5,6,7,8],[9,10,11,12]],[...],[...],[....] ]
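For reference, here is a minimal sketch of the lookup described above: first the element-by-element loop, then one possible vectorized form using advanced indexing. It assumes the first dimension is a batch dimension shared by both tensors and that the indices are 0-based, as in the small example. The names `batch` and `ref` are introduced here only for illustration.

```python
import torch

# Same toy data as above: src is (B, M, D) = (1, 3, 4), idx is (B, N, K) = (1, 5, 2).
src = torch.tensor([[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]]])
idx = torch.tensor([[[0, 1], [1, 2], [0, 2], [1, 0], [2, 1]]])  # int64 by default

B, N, K = idx.shape
D = src.shape[2]

# Loop version: copy one row of src per index entry.
ref = torch.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            ref[b, n, k] = src[b, idx[b, n, k]]

# Vectorized version (assumption: each idx[b] indexes into src[b]).
# batch has shape (B, 1, 1) and broadcasts against idx (B, N, K),
# so src[batch, idx] picks src[b, idx[b, n, k], :] for every (b, n, k).
batch = torch.arange(B).view(-1, 1, 1)
res = src[batch, idx]  # shape (B, N, K, D)

print(res.shape)
print(torch.equal(res, ref))
```

With the full-size tensors, the same indexing expression should produce a result of shape (256, 960, 2, 1024), which is roughly 2 GB in float32, so memory may matter more than speed there.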