Asking for help: what is the best practice for indexing a torch.Tensor with an index tensor?

I have a tensor A of shape (256, 512, 1024) and an index tensor Index of shape (256, 960, 2). An entry such as [1, 2] in Index means I want to retrieve the vectors at positions 1 and 2 of A along dimension 1, within the same batch element. In the end I want a tensor Res of shape (256, 960, 2, 1024).
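Written element-wise, the requested result is as follows. The letters b (batch, size 256), i (index row, size 960), j (entry within a row, size 2) and d (feature, size 1024) are labels I am introducing here, not names from the code:

```latex
\mathrm{Res}[b, i, j, d] = A\bigl[b,\ \mathrm{Index}[b, i, j],\ d\bigr],
\qquad 0 \le b < 256,\; 0 \le i < 960,\; 0 \le j < 2,\; 0 \le d < 1024
```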

I can only implement this by iterating over all elements. Please help me if you know a better approach.

Here is a simplified version of my model:

import torch

if __name__ == '__main__':
    src = torch.tensor([
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    ], dtype=torch.float32)  # shape (1, 3, 4)
    idx = torch.tensor([
        [[0, 1], [1, 2], [0, 2], [1, 0], [2, 1]]
    ], dtype=torch.long)     # shape (1, 5, 2)
    res = torch.zeros((1, 5, 2, 4))
    # I want res to have shape (1, 5, 2, 4):
    # [ [[1,2,3,4],[5,6,7,8]],[[5,6,7,8],[9,10,11,12]],[...],[...],[....] ]
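For reference, the element-by-element loop mentioned above might look like the sketch below on the toy data. The helper name gather_loop is hypothetical; it just spells out res[b, i, j, :] = src[b, idx[b, i, j], :] with three nested Python loops, which is correct but slow for the real (256, 960, 2) index.

```python
import torch

# Toy data from the question: src is (B=1, N=3, D=4), idx is (B=1, M=5, K=2).
src = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
], dtype=torch.float32)
idx = torch.tensor([
    [[0, 1], [1, 2], [0, 2], [1, 0], [2, 1]]
], dtype=torch.long)

def gather_loop(src, idx):
    """Hypothetical reference helper: res[b, i, j, :] = src[b, idx[b, i, j], :]."""
    B, M, K = idx.shape
    D = src.shape[-1]
    res = torch.zeros((B, M, K, D), dtype=src.dtype)
    for b in range(B):
        for i in range(M):
            for j in range(K):
                # idx[b, i, j] is a 0-d integer tensor, so this selects one row of length D
                res[b, i, j] = src[b, idx[b, i, j]]
    return res

res = gather_loop(src, idx)
print(res.shape)  # torch.Size([1, 5, 2, 4])
```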

I believe that is done by:

A.gather(-2, Index.reshape(256, 960*2, 1).expand(-1, -1, 1024)).view(256, 960, 2, 1024)
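A minimal sketch of the gather one-liner generalized to arbitrary shapes and applied to the toy example, assuming Index holds int64 positions along dimension 1 of A. The helper names gather_via_gather and gather_via_indexing are hypothetical. The second helper swaps in a different technique, plain advanced indexing, which should give the same result without the explicit expand.

```python
import torch

def gather_via_gather(A, index):
    # Hypothetical helper. A: (B, N, D); index: (B, M, K) int64 positions along dim 1.
    B, M, K = index.shape
    D = A.shape[-1]
    # gather needs an index with the same number of dims as A, so flatten
    # (M, K) -> M*K, add a trailing axis, and broadcast it over D.
    flat = index.reshape(B, M * K, 1).expand(-1, -1, D)
    # out[b, t, d] = A[b, flat[b, t, d], d]
    out = A.gather(1, flat)               # (B, M*K, D)
    return out.view(B, M, K, D)

def gather_via_indexing(A, index):
    # Hypothetical helper using advanced indexing: the batch index (B, 1, 1)
    # broadcasts against index (B, M, K), and the untouched last dim D is appended.
    B = A.shape[0]
    batch = torch.arange(B, device=A.device).view(B, 1, 1)
    return A[batch, index]                # (B, M, K, D)

src = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]], dtype=torch.float32)
idx = torch.tensor([[[0, 1], [1, 2], [0, 2], [1, 0], [2, 1]]])  # int64 by default

r1 = gather_via_gather(src, idx)
r2 = gather_via_indexing(src, idx)
print(r1.shape)             # torch.Size([1, 5, 2, 4])
print(torch.equal(r1, r2))  # True
```

Both versions materialize the full result. For the real sizes that is 256 × 960 × 2 × 1024 float32 values, roughly 2 GB, so memory rather than speed may become the limiting factor.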
