How to index a tensor with one index per row

Define variables:

>>> a = torch.randn(3, 5)
>>> b = torch.LongTensor([0, 1, 2])
>>> a

 0.0215  0.8084 -1.2216  0.3401  0.1423
-1.5521 -0.5513 -0.5121 -1.5491 -0.9890
-1.2377  1.5195 -0.1469  1.8054 -0.8315
[torch.FloatTensor of size 3x5]

>>> b

 0
 1
 2
[torch.LongTensor of size 3]

I would like to get a tensor like the one below:

0.0215 -0.5513 -0.1469

so I tried this:

>>> a.index_select(dim=0, index=b)

 0.0215  0.8084 -1.2216  0.3401  0.1423
-1.5521 -0.5513 -0.5121 -1.5491 -0.9890
-1.2377  1.5195 -0.1469  1.8054 -0.8315
[torch.FloatTensor of size 3x5]

>>> a.index_select(dim=1, index=b)

 0.0215  0.8084 -1.2216
-1.5521 -0.5513 -0.5121
-1.2377  1.5195 -0.1469
[torch.FloatTensor of size 3x3]

Neither of these is the tensor I want.
How can I get it?

If you really just want the indices [0, 1, 2], a.diag() will give you what you want.
If the indices can select elements other than the "diagonal" ones, a.gather(1, b.unsqueeze(1)) will give you what you want.
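A minimal sketch of both suggestions, assuming a fixed copy of the tensor from the question (instead of torch.randn) so the result is reproducible. Note that gather returns a tensor of shape (3, 1); the trailing squeeze(1) is an extra step, added here to get the flat shape (3,) shown in the question.

```python
import torch

# Fixed values copied from the question, so the output is reproducible.
a = torch.tensor([[0.0215, 0.8084, -1.2216, 0.3401, 0.1423],
                  [-1.5521, -0.5513, -0.5121, -1.5491, -0.9890],
                  [-1.2377, 1.5195, -0.1469, 1.8054, -0.8315]])
b = torch.tensor([0, 1, 2])  # int64 (LongTensor): one column index per row

# Only correct when b is exactly [0, 1, ..., n-1], i.e. the main diagonal.
diag_vals = a.diag()

# General case: gather along dim=1 needs an index with the same number of
# dimensions as `a`, so unsqueeze(1) turns shape (3,) into (3, 1).
picked = a.gather(1, b.unsqueeze(1))  # shape (3, 1): picked[i][0] = a[i][b[i]]
picked_flat = picked.squeeze(1)       # shape (3,), like the desired output

print(picked_flat)
```

With b = [0, 1, 2] both approaches return the same three values; only gather keeps working when b is arbitrary.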


Sorry for the ambiguous question. Actually, I would like to obtain the values at the column indices given by b, not the diagonal elements.
But thanks!

In that case, a.gather(1, b.unsqueeze(1)) is what you're looking for!
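To illustrate the non-diagonal case, here is a small sketch with hypothetical indices b = [4, 0, 2] (not from the thread). It compares gather with plain advanced indexing, a[rows, b], which is a different technique that yields the same per-row selection; the squeeze(1) is again only there to flatten gather's (3, 1) result.

```python
import torch

# Same values as in the question.
a = torch.tensor([[0.0215, 0.8084, -1.2216, 0.3401, 0.1423],
                  [-1.5521, -0.5513, -0.5121, -1.5491, -0.9890],
                  [-1.2377, 1.5195, -0.1469, 1.8054, -0.8315]])
b = torch.tensor([4, 0, 2])  # hypothetical non-diagonal column indices

# gather: out[i][0] = a[i][b[i]]; squeeze(1) drops the singleton column.
via_gather = a.gather(1, b.unsqueeze(1)).squeeze(1)

# Alternative (advanced indexing): pair each row index i with column b[i].
rows = torch.arange(a.size(0))
via_indexing = a[rows, b]

print(via_gather)  # a[0, 4], a[1, 0], a[2, 2]
```

Both forms pick a[i, b[i]] for every row i; gather generalizes to other dims and higher-rank tensors, while advanced indexing is often easier to read for the 2-D case.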


Oh, I overlooked that. That gives me exactly the result I wanted!
Thank you!