Is there a better way to perform row-corresponding indexing?

So let’s say I have a tensor A and an indexing tensor idxs as follows:

A = torch.Tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.Tensor([[0, 2],
                     [1, 2],
                     [2, 0]]).int()

I want to find a fast method of indexing where the first row of idxs indexes the first row of A, the second row of idxs indexes the second row, and so on. In this way I’d have the following output

[[8, 5],
 [6, 1],
 [9, 2]]

I have a solution as shown below

A[:, idxs][torch.eye(len(A))]

This works quite well, but I am worried that if A has tens of thousands of rows there will be memory issues due to the torch.eye call.
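For scale, here is a rough back-of-envelope estimate of what those intermediates would cost; the row count n = 50_000 below is just an illustrative assumption, not a measured figure:

# Hypothetical sizes, for illustration only.
n, k = 50_000, 2          # n rows in A, k indices per row in idxs
bytes_per_elem = 4        # float32, the default dtype of torch.Tensor

# torch.eye(n) materialises a dense n x n matrix,
# even though all but n of its entries are zero.
print(f"torch.eye(n): ~{n * n * bytes_per_elem / 1e9:.0f} GB")      # ~10 GB

# The A[:, idxs] intermediate is even larger: it has shape (n, n, k),
# because every row of idxs is applied to every row of A.
print(f"A[:, idxs]:   ~{n * n * k * bytes_per_elem / 1e9:.0f} GB")  # ~20 GB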

  1. Is PyTorch’s storage of tensors sufficiently efficient that there should be no issues at all? Or,
  2. Should I be worried about this and change my solution to something better?
  3. Regardless of both questions, is there a better indexing method than what I have shown?

Thanks in advance!

Your code doesn’t work for me and fails with:

A = torch.Tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.Tensor([[0, 2],
                     [1, 2],
                     [2, 0]]).int()
A[:, idxs][torch.eye(len(A))]
# IndexError: tensors used as indices must be long, int, byte or bool tensors

Fixing this by calling torch.eye(...).long() gives:

A[:, idxs][torch.eye(len(A)).long()]
# tensor([[[[7., 1.],
#           [6., 1.],
#           [1., 7.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]]],


#         [[[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[7., 1.],
#           [6., 1.],
#           [1., 7.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]]],


#         [[[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[8., 5.],
#           [3., 5.],
#           [5., 8.]],

#          [[7., 1.],
#           [6., 1.],
#           [1., 7.]]]])

so it’s also not your expected output, and I’m thus unsure how your approach worked. (The shape blow-up happens because A[:, idxs] has shape (3, 3, 2), and indexing that with a (3, 3) index tensor yields shape (3, 3, 3, 2).)

In any case, torch.gather should work:

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])
torch.gather(A, 1, idxs)
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
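
For completeness, advanced indexing with a broadcast column of row indices should give the same result, without materializing any torch.eye intermediate:

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])
# Row indices of shape (3, 1) broadcast against idxs of shape (3, 2),
# so row i of idxs picks columns from row i of A.
rows = torch.arange(A.size(0)).unsqueeze(1)
A[rows, idxs]
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])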