Indexing a 2d tensor using a 2d tensor of indices

Hi, I have a quick question about indexing.

I have a 2d tensor src of size (n, 2) storing n 2d points, and another 2d tensor index of size (224, 224) storing integer indices into the rows of src (values in [0, n)). I would like to fill a 3d tensor output of size (224, 224, 2) so that

output[i][j] = src[index[i][j]]
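For reference, a literal (and slow) translation of that rule into Python loops might look like this. It is only a sketch for checking faster versions against; the small sizes and the random index are made up for the example.

```python
import torch

# Small hypothetical sizes so the double loop finishes quickly
n, h, w = 5, 4, 3
src = torch.arange(1, 2 * n + 1).view(n, 2)   # n 2d points, one per row
index = torch.randint(0, n, (h, w))           # each entry picks a row of src

# Literal translation of output[i][j] = src[index[i][j]]
output = torch.empty(h, w, 2, dtype=src.dtype)
for i in range(h):
    for j in range(w):
        output[i][j] = src[index[i][j]]
```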

It doesn’t seem like a difficult task, but due to my noobness I can’t find a way to do this. I played around with scatter_ and gather, but they seem to serve a different purpose. Is there a simple way to do this in PyTorch? Thanks in advance for your help!
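gather can in fact express this, with a small trick. With dim=0, torch.gather computes out[m][c] = src[idx[m][c]][c], so if each flattened index is repeated across both coordinate columns, every output row becomes a full row of src. A minimal sketch, assuming a randomly generated index for illustration:

```python
import torch

n, h, w = 5, 224, 224
src = torch.arange(1, 2 * n + 1).view(n, 2)   # n 2d points
index = torch.randint(0, n, (h, w))           # int64 indices into rows of src

# gather along dim 0: out[m][c] = src[idx[m][c]][c],
# so repeat each flat index across both coordinate columns
idx = index.reshape(-1, 1).expand(-1, 2)      # shape (h*w, 2)
output = torch.gather(src, 0, idx).view(h, w, 2)
```

This works, but plain tensor indexing does the same job with less ceremony.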

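import torch

# create dummy src and index for the example
n = 5
h = 224
w = 224
src = torch.arange(1, 11).view(n, 2)  # n points, 2 coordinates each
index = torch.arange(1, h * w + 1).view(h, w).remainder_(n)  # values in [0, n)

# flatten index, pick the matching rows of src, then restore the (h, w, 2) shape
output = src[index.view(-1).long(), :].view(h, w, 2)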
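The flatten/reshape round trip should not be strictly necessary: indexing with an integer tensor keeps the index tensor's shape, so src[index] should already come out as (h, w, 2). The sketch below compares that with torch.index_select (which needs a 1d index, hence the flatten) and with the version above; the dummy data is the same as in the answer.

```python
import torch

n, h, w = 5, 224, 224
src = torch.arange(1, 11).view(n, 2)
index = torch.arange(1, h * w + 1).view(h, w).remainder_(n)

# Integer-tensor indexing: result shape is index.shape + src.shape[1:] == (h, w, 2)
direct = src[index]

# index_select along dim 0 takes a 1d index, so flatten and reshape back
selected = torch.index_select(src, 0, index.view(-1)).view(h, w, 2)

# Flatten-then-reshape version from the answer, for comparison
flat = src[index.view(-1).long(), :].view(h, w, 2)
```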