# Is there a better way to perform row-corresponding indexing?

So let's say I have a tensor `A` and an index tensor `idxs` as follows:

``````python
import torch

A = torch.Tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.Tensor([[0, 2],
                     [1, 2],
                     [2, 0]]).int()
``````

I want a fast way of indexing where the first row of `idxs` indexes the first row of `A`, the second row of `idxs` indexes the second row of `A`, and so on, so that the output is:

``````
[[8, 5],
 [6, 1],
 [9, 2]]
``````
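
For reference, the intended operation can be written as an explicit per-row loop. This is only a sketch to pin down the semantics (`out[i][j] = A[i][idxs[i][j]]`). `row_index_loop` is a hypothetical helper name, and the index tensor here is created as `int64` rather than with `.int()`:

```python
import torch

def row_index_loop(A, idxs):
    # Hypothetical reference helper: row i of the output is A[i] indexed by idxs[i],
    # so out[i][j] == A[i][idxs[i][j]]. Slow for large inputs (Python loop over rows).
    return torch.stack([A[i][idxs[i]] for i in range(A.shape[0])])

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])

print(row_index_loop(A, idxs))
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
```

A loop like this is too slow for tens of thousands of rows, but it is handy as ground truth for checking a vectorized version.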

My current solution is:

``````python
A[:, idxs][torch.eye(len(A))]
``````

It works quite well, but I'm worried that if `A` has tens of thousands of rows there will be memory issues, since `torch.eye(len(A))` allocates a dense `len(A) x len(A)` matrix.
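
To put a number on that worry: `torch.eye(n)` is a dense tensor with `n * n` elements, stored as `float32` (4 bytes each) under the default dtype. Below is a quick sketch of the arithmetic. `dense_nbytes` is a hypothetical helper, and the row counts are just illustrative values for "tens of thousands":

```python
import torch

def dense_nbytes(n_rows, n_cols, dtype=torch.float32):
    # Hypothetical helper: bytes needed for a dense n_rows x n_cols tensor of `dtype`
    return n_rows * n_cols * torch.empty(0, dtype=dtype).element_size()

# Sanity check on a small eye, assuming the default dtype is still float32
small = torch.eye(1000)
assert small.nelement() * small.element_size() == dense_nbytes(1000, 1000)

print(dense_nbytes(10_000, 10_000))  # 400000000 bytes, about 400 MB
print(dense_nbytes(50_000, 50_000))  # 10000000000 bytes, about 10 GB
```

Even a `bool` mask (1 byte per element) would still take about 100 MB at 10,000 rows, so the concern is justified.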

1. Is PyTorch's storage of tensors efficient enough that there should be no issues at all? Or,
3. Regardless of both questions, is there a better indexing method than what I have shown?

Your code doesn’t work for me and fails with:

``````python
import torch

A = torch.Tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.Tensor([[0, 2],
                     [1, 2],
                     [2, 0]]).int()
A[:, idxs][torch.eye(len(A))]  # torch.eye returns float32, which is not a valid index dtype
# IndexError: tensors used as indices must be long, int, byte or bool tensors
``````

Fixing this by calling `torch.eye(...).long()` gives:

``````python
A[:, idxs][torch.eye(len(A)).long()]
# tensor([[[[7., 1.],
#           [6., 1.],
#           [1., 7.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]]],

#         [[[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[7., 1.],
#           [6., 1.],
#           [1., 7.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]]],

#         [[[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[7., 1.],
#           [6., 1.],
#           [1., 7.]]]])
``````

which is also not your expected output, so I'm unsure how your approach worked for you.
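
One guess at what the original code was meant to do (an assumption, not confirmed in the thread): with a boolean mask instead of a float or long one, the diagonal selection does produce the expected output. The sketch below shows why, and also why it scales badly. `A[:, idxs]` first builds every row of `A` indexed by every row of `idxs`:

```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])

# A[:, idxs] has shape (n, n, k): entry [r, i, j] = A[r, idxs[i, j]],
# i.e. every row of A indexed by every row of idxs.
expanded = A[:, idxs]
print(expanded.shape)  # torch.Size([3, 3, 2])

# A boolean diagonal mask keeps only the [i, i, :] entries,
# which are exactly A[i, idxs[i, :]].
mask = torch.eye(len(A)).bool()
out = expanded[mask]
print(out)
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
```

The intermediate `expanded` has `len(A) * len(A) * idxs.shape[1]` elements. At 10,000 rows and 2 indices per row, that is 200 million values, and the mask then discards all but 20,000 of them. So this approach uses even more memory than the `torch.eye` mask itself.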

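In any case, `torch.gather` should work. Along `dim=1` it computes `out[i][j] = A[i][idxs[i][j]]`, which is exactly the row-wise indexing you describe, and it needs no `len(A) x len(A)` intermediate. Note that the index tensor must be `int64`, which `torch.tensor` produces by default for integer data: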

``````python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])
torch.gather(A, 1, idxs)
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
``````
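
The question builds its index tensor with `.int()`. The sketch below converts it to `int64` before calling `torch.gather` and cross-checks two equivalent formulations. One is plain advanced indexing with a broadcast row index. The other is `torch.take_along_dim`, which as far as I know is available from PyTorch 1.9 onward. None of them builds a `len(A) x len(A)` intermediate. The advanced-indexing version only adds a `len(A) x 1` row-index tensor:

```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
# Index tensor built the way the question does it (int32)
idxs_int32 = torch.tensor([[0, 2],
                           [1, 2],
                           [2, 0]]).int()

# torch.gather expects an int64 index tensor, so convert first
idxs = idxs_int32.long()
via_gather = torch.gather(A, 1, idxs)

# Equivalent advanced indexing: the (n, 1) row index broadcasts against
# idxs (n, k), so element [i, j] reads A[i, idxs[i, j]]
rows = torch.arange(A.shape[0]).unsqueeze(1)
via_indexing = A[rows, idxs]

# torch.take_along_dim mirrors numpy.take_along_axis
via_take = torch.take_along_dim(A, idxs, dim=1)

assert torch.equal(via_gather, via_indexing)
assert torch.equal(via_gather, via_take)
print(via_gather)
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
```

All three allocate only the `len(A) x k` result, so memory grows linearly with the number of rows.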