[pytorch] Indexing a 2D tensor with a 2D long tensor raises an error

What I want to do is gather elements from a 10x20 tensor x according to the class-label tensor y, whose shape is (10,).

I don’t understand this error, because I never asked for index 19 in dim 0.

Could anyone help me?

❯ ipython
Python 3.6.5 |Anaconda custom (64-bit)| (default, Apr 29 2018, 16:14:56)
Type 'copyright', 'credits' or 'license' for more information
IPython 6.5.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: x = torch.empty(10, 20).uniform_().float(); y = torch.randint(0, 20, (10,)).long()

In [3]: indices = torch.stack((torch.arange(x.size(0)), y), 1)

In [4]: x[indices]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-a9c39fcc26e5> in <module>()
----> 1 x[indices]

RuntimeError: index 19 is out of bounds for dimension 0 with size 10

In [5]: indices
Out[5]:
tensor([[ 0, 18],
        [ 1, 10],
        [ 2,  4],
        [ 3, 13],
        [ 4,  7],
        [ 5,  5],
        [ 6,  0],
        [ 7,  8],
        [ 8, 19],
        [ 9,  8]])
In [6]: torch.__version__
Out[6]: '1.0.0.dev20180921'

In [8]: x[9, 19]
Out[8]: tensor(0.8050)  <- of course, this worked.

Hi,

You don’t actually need to build 2D indices; you can use gather for that:

import torch

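x = torch.empty(10, 20).uniform_().float()  # 10 samples x 20 class scores
y = torch.randint(0, 20, (10,)).long()      # one class label per sample
print(x)
print(y)

out = x.gather(1, y.unsqueeze(-1))  # out[i, 0] = x[i, y[i]]; shape (10, 1), use .squeeze(-1) for (10,)
print(out)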
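A quick way to check what `gather` does here: along `dim=1` it picks `out[i, 0] = x[i, y[i]]`, i.e. one element per row chosen by that row's label. `gather` needs an index with the same number of dimensions as `x`, which is why `y` is unsqueezed to shape (10, 1). The sketch below compares it against a plain Python loop; the fixed seed, `torch.rand` and the names `picked`/`manual` are only illustrative choices, not part of the original answer.

```python
import torch

torch.manual_seed(0)             # fixed seed only so the run is reproducible
x = torch.rand(10, 20)           # 10 samples, 20 class scores each
y = torch.randint(0, 20, (10,))  # one class label per sample (int64 by default)

# gather along dim 1: out[i, 0] = x[i, y[i, 0]] with y unsqueezed to (10, 1)
out = x.gather(1, y.unsqueeze(-1))  # shape (10, 1)
picked = out.squeeze(-1)            # shape (10,)

# The same selection written as an explicit loop, for comparison
manual = torch.stack([x[i, int(y[i])] for i in range(x.size(0))])

print(out.shape, picked.shape)      # torch.Size([10, 1]) torch.Size([10])
print(torch.equal(picked, manual))  # True
```

The loop is only there to make the semantics visible; `gather` does the same lookup in a single vectorized call.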

Oh, appreciate this so much!
I love this simple, readable way!!!

Do you have any idea why my method failed?

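A single index tensor only indexes dimension 0: every value in indices is treated as a row number, so x[indices] tries to fetch whole rows such as row 18 or 19 of a tensor that has only 10 rows. That is why the error talks about dim 0. You should not pass the indices as one 2D tensor, but as two 1D tensors (one per dimension): out = x[torch.arange(x.size(0)), y].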
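To make the failure concrete: with a single index tensor, the result has shape `indices.shape + x.shape[1:]`, because each entry selects a whole row. The sketch below uses a hypothetical 3x4 tensor and made-up labels whose values all happen to be valid row numbers, so `x[indices]` succeeds and its shape shows what it actually did. It also shows that the stacked `indices` from the question still work once they are split back into one 1D tensor per dimension.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # hypothetical small example: 3 rows, 4 columns
y = torch.tensor([1, 2, 0])         # hypothetical "labels", one column per row

# The original approach: one (N, 2) index tensor
indices = torch.stack((torch.arange(x.size(0)), y), 1)  # [[0, 1], [1, 2], [2, 0]]

# A single tensor index only addresses dim 0: every value is read as a row number
rows = x[indices]
print(rows.shape)  # torch.Size([3, 2, 4]) -> a (3, 2) grid of whole rows

# Two 1D tensors index dims 0 and 1 together: picks x[i, y[i]]
picked = x[torch.arange(x.size(0)), y]
print(picked)  # tensor([1, 6, 8])

# The stacked indices work too if split back into columns
same = x[indices[:, 0], indices[:, 1]]  # equivalently: x[tuple(indices.t())]
print(torch.equal(picked, same))  # True
```

In the original 10x20 case, the second column holds values up to 19, which is why the single-tensor version goes out of bounds for dim 0.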
