Say I have a tensor `A` and an index tensor `idxs`, as follows:
```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])  # int64 (torch.long) by default
```
I want a fast way to index so that the first row of `idxs` indexes the first row of `A`, the second row of `idxs` indexes the second row, and so on. For this example, the output should be:
```
[[8, 5],
 [6, 1],
 [9, 2]]
```
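Stated precisely, the desired result satisfies `out[i, j] = A[i, idxs[i, j]]`. A plain Python loop makes that definition explicit. It is only a slow reference for checking faster versions against, and it assumes the tensors are built as int64 (an assumption, chosen because int64 is the index type every PyTorch indexing path accepts).

```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])  # int64 by default

def rowwise_index_loop(A, idxs):
    # Reference only: out[i, j] = A[i, idxs[i, j]], built one row at a time.
    rows = [A[i, idxs[i]] for i in range(A.shape[0])]
    return torch.stack(rows)

out_loop = rowwise_index_loop(A, idxs)
print(out_loop)
# tensor([[8, 5],
#         [6, 1],
#         [9, 2]])
```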
My current solution is:

```python
A[:, idxs][torch.eye(len(A)).bool()]  # the mask must be bool; a float eye cannot be used as an index
```
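Broken into steps, the mask approach first indexes every row of `A` with every row of `idxs`, then keeps only the diagonal. The sketch below assumes int64 tensors and converts the `torch.eye` output to bool, since PyTorch rejects float tensors as indices.

```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])

# Step 1: every row of A paired with every row of idxs.
# all_pairs[i, j, k] = A[i, idxs[j, k]], shape (N, N, k) = (3, 3, 2).
all_pairs = A[:, idxs]

# Step 2: keep only the entries where i == j using a bool diagonal mask.
mask = torch.eye(len(A)).bool()   # shape (N, N)
out_mask = all_pairs[mask]        # shape (N, k) = (3, 2)
print(out_mask)
```

Note that step 1 materializes an (N, N, k) tensor before the mask discards all but N rows of it.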
This works well, but I am worried that if `A` has tens of thousands of rows, the N×N tensor created by `torch.eye` will cause memory issues.
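A back-of-the-envelope estimate helps size the concern. The sketch assumes N = 20,000 rows (a hypothetical stand-in for "tens of thousands"), k = 2 indices per row, int64 data, and the default float32 dtype for `torch.eye`. Under those assumptions the bool mask is not the largest cost; the (N, N, k) intermediate from `A[:, idxs]` is.

```python
# Rough memory estimate for the mask-based approach (assumed sizes).
N = 20_000          # hypothetical number of rows in A
k = 2               # indices per row, as in idxs above
BYTES_INT64 = 8
BYTES_FLOAT32 = 4
BYTES_BOOL = 1

eye_float_bytes = N * N * BYTES_FLOAT32        # torch.eye(N), default float32
eye_bool_bytes = N * N * BYTES_BOOL            # after .bool()
all_pairs_bytes = N * N * k * BYTES_INT64      # A[:, idxs] for int64 A
result_bytes = N * k * BYTES_INT64             # the (N, k) output actually needed

for name, size in [("eye (float32)", eye_float_bytes),
                   ("eye (bool)", eye_bool_bytes),
                   ("A[:, idxs]", all_pairs_bytes),
                   ("result", result_bytes)]:
    print(f"{name:>14}: {size / 1e9:.4f} GB")
```

With these numbers the intermediate is about 6.4 GB, the bool mask about 0.4 GB, and the result itself only 320 KB.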
- Is PyTorch's tensor storage efficient enough that this won't be a problem at all? Or,
- Should I be worried about this and switch to something better?
- Regardless of the answers to the above, is there a better indexing method than the one I've shown?
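For comparison, one widely used alternative is `torch.gather` along dim 1, which computes `out[i][j] = A[i][idxs[i][j]]` directly and allocates only the (N, k) result. Plain advanced indexing with a broadcast row index, `A[torch.arange(N)[:, None], idxs]`, gives the same result. A minimal sketch, assuming int64 indices (the documented index type for `torch.gather` is LongTensor):

```python
import torch

A = torch.tensor([[8, 3, 5],
                  [7, 6, 1],
                  [2, 4, 9]])
idxs = torch.tensor([[0, 2],
                     [1, 2],
                     [2, 0]])  # int64; convert with idxs.long() if needed

# Option 1: gather along dim 1, i.e. out[i][j] = A[i][idxs[i][j]].
out_gather = torch.gather(A, 1, idxs)

# Option 2: advanced indexing. The (N, 1) row index broadcasts
# against the (N, k) column index, pairing row i with idxs[i, :].
rows = torch.arange(A.shape[0]).unsqueeze(1)
out_adv = A[rows, idxs]

print(out_gather)
print(torch.equal(out_gather, out_adv))  # True
```

Neither option builds an N×N mask or an (N, N, k) intermediate, so memory grows with N·k rather than N²·k.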
Thanks in advance!