How to use advanced indexing on a 4D tensor

Hi all,

I am a little confused about advanced indexing on a 4D tensor.

For example, suppose I have a 4D tensor `x = torch.randn(10, 3, 5, 5)`, i.e. a mini-batch of 10 RGB images, each of size 5×5.

I would like to sample 3 points from each 5×5 image in this mini-batch. The coordinates of the points are stored in two long tensors, `row` and `col`, each of shape 10×3. Can I index these points without a loop such as `[x[i, :, row[i], col[i]] for i in range(10)]`?
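A minimal sketch (not from the thread) of a loop-free version using plain advanced indexing. It assumes `row` and `col` are long tensors of shape (N, K) holding valid row and column coordinates; random values stand in for the real coordinates here, and `K`, `batch`, and `vec` are illustrative names. A batch index of shape (N, 1) broadcasts against `row` and `col`, so each image is paired with its own K points. Because the tensor indices are separated by the `:` slice, PyTorch (following NumPy's rules) places the broadcast dimensions first, so the result has shape (N, K, C).

```python
import torch

N, C, H, W, K = 10, 3, 5, 5, 3  # K = number of points sampled per image
x = torch.randn(N, C, H, W)
# Stand-ins for the question's coordinate tensors, each of shape (N, K)
row = torch.randint(H, (N, K))
col = torch.randint(W, (N, K))

# Batch index of shape (N, 1); it broadcasts with row/col to (N, K)
batch = torch.arange(N).unsqueeze(1)
# Advanced indices split by a slice -> broadcast dims first: shape (N, K, C)
vec = x[batch, :, row, col]
```

If a channel-first layout (N, C, K) is preferred, `vec.permute(0, 2, 1)` rearranges it without copying.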

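```python
import torch

N, C, H, W = 10, 3, 5, 5
x = torch.randn(N, C, H, W)
# One random flattened spatial index (row * W + col) per (image, channel) pair
indices = torch.randint(H * W, (N, C, 1), dtype=torch.long)
# Flatten each H x W map, pick the indexed element along dim 2, drop the singleton dim
out = torch.gather(x.view(N, C, H * W), 2, indices).view(N, C)  # shape (N, C)
```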
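The reply draws one random flattened index per (image, channel) pair. A sketch of how the same `torch.gather` idea could be adapted to the question's `row` and `col` tensors, assuming both have shape (N, K): flatten each coordinate pair into `row * W + col`, then expand it across the channel dimension so all C channels are read at the same K positions. The names `K`, `flat`, and `index` are illustrative.

```python
import torch

N, C, H, W, K = 10, 3, 5, 5, 3  # K = number of points sampled per image
x = torch.randn(N, C, H, W)
# Stand-ins for the question's coordinate tensors, each of shape (N, K)
row = torch.randint(H, (N, K))
col = torch.randint(W, (N, K))

# Flatten (row, col) into a single position within an H*W map: shape (N, K)
flat = row * W + col
# gather needs an index with the same number of dims as its input:
# (N, K) -> (N, 1, K) -> (N, C, K); expand shares memory instead of copying
index = flat.unsqueeze(1).expand(N, C, K)
# out[n, c, k] == x[n, c, row[n, k], col[n, k]]; shape (N, C, K)
out = torch.gather(x.reshape(N, C, H * W), 2, index)
```

Compared with advanced indexing, this keeps the channel dimension in place, so the result comes out as (N, C, K) directly.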