Tensor indexing with another tensor

Hi, I usually index tensors with lists of indices, like this:

import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
index = [[0, 1, 1], [1, 1, 2]]
x[index]  # tensor([2, 7, 8])
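As the output shows, PyTorch reads a list of lists like index as if it were a tuple with one index list per dimension. The inner lists are paired elementwise, so the result holds x[0, 1], x[1, 1] and x[1, 2]. A minimal sketch that spells out that tuple form explicitly (the names rows, cols and picked are only for illustration):

```python
import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])

rows = [0, 1, 1]  # first inner list: indices into dim 0
cols = [1, 1, 2]  # second inner list: indices into dim 1

# Advanced indexing pairs the two sequences elementwise:
# (0, 1) -> 2, (1, 1) -> 7, (1, 2) -> 8
picked = x[rows, cols]
print(picked)  # tensor([2, 7, 8])
```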

Now I need index to be a tensor, but when I do that I get an error:

import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
index = torch.as_tensor([[0, 1, 1], [1, 1, 2]])
x[index]  # IndexError: index 2 is out of bounds for dimension 0 with size 2
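The error comes from a different rule: a single integer tensor is treated as one index into dimension 0 only. Every value is read as a row number, and 2 is out of range for a tensor with only 2 rows. The sketch below uses a hypothetical index tensor whose values are all valid row numbers, to show the resulting shape (index shape followed by the remaining dimensions of x):

```python
import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])

# Hypothetical index: every value must be a valid row number (0 or 1 here)
idx = torch.as_tensor([[0, 1, 1], [1, 1, 0]])
rows = x[idx]

# Result shape = idx.shape + x.shape[1:] = (2, 3) + (5,)
print(rows.shape)  # torch.Size([2, 3, 5])
# Each entry of idx selects a whole row of x
print(rows[0, 0])  # tensor([1, 2, 3, 4, 5])
```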

I don't know whether tensors are expected to behave differently, but a NumPy array behaves like the list:

import numpy
import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
index = [[0, 1, 1], [1, 1, 2]]
index_n = numpy.asarray(index)
index_t = torch.as_tensor(index)

x[index]    # tensor([2, 7, 8])
x[index_n]  # tensor([2, 7, 8])
x[index_t]  # IndexError: index 2 is out of bounds for dimension 0 with size 2

How can I get the same output using a tensor instead of a list?

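index = torch.as_tensor([[0, 1, 1], [1, 1, 2]])  # index must be a tensor here
ind = index[0, :] * x.size(1) + index[1, :]  # flat position = row * num_cols + col
torch.take(x, ind)  # torch.take reads x as if flattened: tensor([2, 7, 8])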
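Another option, which avoids computing flat positions, is to unpack the index tensor so that each of its rows indexes one dimension. That reproduces what the nested list does. A minimal sketch, assuming index_t holds one row of indices per dimension as in the question:

```python
import torch

x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
index_t = torch.as_tensor([[0, 1, 1], [1, 1, 2]])

# tuple(index_t) iterates over dim 0, giving (index_t[0], index_t[1]):
# one index tensor per dimension of x
out = x[tuple(index_t)]
print(out)  # tensor([2, 7, 8])

# Equivalent, more explicit spelling
out2 = x[index_t[0], index_t[1]]
# unbind(0) also returns a tuple of the rows of index_t
out3 = x[index_t.unbind(0)]
```

Unlike the torch.take version, this does not depend on the number of columns, and it also works when index_t covers only the leading dimensions.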

Thank you for your answer. I realize now that my example was too simple: I need to do this on 3D tensors. My solution for now is:

import torch

x = torch.as_tensor(
    [
        [[1, 2, 3, 4, 5], [6, 7, 8, 9, 0], [11, 12, 13, 14, 15]],
        [[16, 17, 18, 19, 20], [21, 22, 23, 24, 25], [26, 27, 28, 29, 30]],
    ]
)

index = [[0, 0, 1], [1, 2, 0]]
# tensor([[ 6,  7,  8,  9,  0],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20]])
x[index]

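index_t = torch.as_tensor(index)
# Step 1: select the dim-0 slices x[0], x[0], x[1] -> shape (3, 3, 5)
out = x.index_select(0, index_t[0])
# Step 2: from slice i take row index_t[1][i] -> shape (3, 1, 5); drop only the size-1 dim
out = out[torch.arange(out.shape[0]).unsqueeze(-1), index_t[1].unsqueeze(-1)].squeeze(1)
# tensor([[ 6,  7,  8,  9,  0],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20]])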

It is a bit convoluted, but it gets the job done. If there is a better way to do it, I'd be happy to learn 🙂
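The tuple-unpacking approach appears to cover the 3D case as well: indexing the first two dimensions with one tensor each keeps the last dimension whole, so neither index_select nor squeeze is needed. A sketch using the x and index from the post:

```python
import torch

x = torch.as_tensor(
    [
        [[1, 2, 3, 4, 5], [6, 7, 8, 9, 0], [11, 12, 13, 14, 15]],
        [[16, 17, 18, 19, 20], [21, 22, 23, 24, 25], [26, 27, 28, 29, 30]],
    ]
)
index_t = torch.as_tensor([[0, 0, 1], [1, 2, 0]])

# One index tensor for dim 0 and one for dim 1; dim 2 is kept whole.
# The pairs (0, 1), (0, 2), (1, 0) select three rows of length 5.
out = x[tuple(index_t)]
print(out.shape)  # torch.Size([3, 5])
```

The values match the x[index] output shown in the post.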


In an issue raised in the PyTorch GitHub repository, the person who opened it replaced torch.Tensor.__getitem__ with their own lambda.

I made the same change as follows:

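import torch
f = torch.Tensor.__getitem__  # keep a reference to the original method
g = lambda *a, **kw: f(*a, **kw)  # wrapper that forwards all arguments to it
torch.Tensor.__getitem__ = g  # patch the class: affects indexing on every tensor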

With this change in place, I tried indexing x with index_t again, and it returned the expected result.

Does it break other tensor functionalities?

Patching __getitem__ only affects indexing, although it does so for every tensor in the program. As far as I have checked, it works as expected for lists, NumPy arrays and tensors. If you find any bugs with this implementation, I'd be curious to know what they are 🙂
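If you want to try the patch without leaving it active for the whole program, one option is a small context manager that applies the wrapper and restores the saved method on exit. This is only a sketch: patched_getitem is a hypothetical helper name, and it limits the scope of the patch without changing what the patch does inside the block.

```python
import contextlib

import torch


@contextlib.contextmanager
def patched_getitem():
    # Hypothetical helper: install the lambda wrapper from this thread,
    # then put the original method back even if the block raises.
    original = torch.Tensor.__getitem__
    torch.Tensor.__getitem__ = lambda *a, **kw: original(*a, **kw)
    try:
        yield
    finally:
        torch.Tensor.__getitem__ = original


x = torch.as_tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
with patched_getitem():
    inside = x[[0, 1], [1, 2]]  # indexing goes through the wrapper
after = x[[0, 1], [1, 2]]  # original method restored
```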