PyTorch operation: get a tensor from another tensor using an index tensor

Tensor A: shape (1024, 128, 3)
Tensor B: shape (1024, 10) (the values in B are indices into A along dimension 1)
Tensor C: shape (1024, 10, 3)
How can I get tensor C from A using the indices in B?
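To pin down the intended semantics, here is a slow reference version written as an explicit loop. It assumes the intent is `C[i, j, :] = A[i, B[i, j], :]`; the helper name `gather_rows_loop` is hypothetical, and small sizes are used so the loop runs quickly.

```python
import torch

def gather_rows_loop(A, B):
    # Hypothetical reference helper: C[i, j, :] = A[i, B[i, j], :]
    N, K = B.shape
    C = A.new_empty(N, K, A.size(2))
    for i in range(N):
        for j in range(K):
            # Pick row B[i, j] from the i-th (128, 3) slice of A
            C[i, j] = A[i, B[i, j].item()]
    return C

A = torch.randn(4, 128, 3)
B = torch.randint(0, 128, (4, 10))
C = gather_rows_loop(A, B)
print(C.shape)  # torch.Size([4, 10, 3])
```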

This should work:

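import torch

A = torch.randn(1024, 128, 3)
B = torch.randint(0, 128, (1024, 10))
# The row index of shape (1024, 1) broadcasts against B (1024, 10),
# so C[i, j] = A[i, B[i, j]]
C = A[torch.arange(A.size(0)).unsqueeze(1), B]
print(C.shape)
# torch.Size([1024, 10, 3])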
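The indexing above works because the two index tensors broadcast to shape (1024, 10), and the untouched last dimension (size 3) is carried along. As an alternative sketch, the same result can be obtained with `torch.gather`, which requires an index tensor with the same number of dimensions as `A`; here `B` is expanded along the last dimension to make that so. The variable names are illustrative only.

```python
import torch

A = torch.randn(1024, 128, 3)
B = torch.randint(0, 128, (1024, 10))

# Advanced indexing: (1024, 1) row indices broadcast with B (1024, 10)
rows = torch.arange(A.size(0)).unsqueeze(1)
C_index = A[rows, B]

# torch.gather along dim 1: expand B from (1024, 10) to (1024, 10, 3)
idx = B.unsqueeze(-1).expand(-1, -1, A.size(2))
C_gather = torch.gather(A, 1, idx)

print(C_gather.shape)                  # torch.Size([1024, 10, 3])
print(torch.equal(C_index, C_gather))  # True
```

Both approaches produce the same values; advanced indexing is shorter, while `gather` makes the indexed dimension explicit.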