I have a 2D input tensor and a 2D index tensor with the same number of rows, where each value in a row of the index tensor is a column index I want to select from the corresponding row of the input tensor.
I want as output a 2D tensor with the same number of rows as the input tensor and the same number of columns as the index tensor.
I can do this by building flattened row and column index tensors, indexing the input tensor with them, and then reshaping the result, but is there a more efficient way to do this?
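Put another way, the result should satisfy `out[i][j] == a[i][idx[i][j]]`. A slow but unambiguous plain-Python reference of that definition (nested lists, no PyTorch; the helper name `select_columns` is hypothetical) can be handy for checking any faster tensor version:

```python
def select_columns(a, idx):
    """Hypothetical reference helper: out[i][j] == a[i][idx[i][j]]."""
    # The index "matrix" must have one row per row of the input.
    assert len(a) == len(idx)
    # For each input row, pick the columns listed in the matching index row.
    return [[row[c] for c in cols] for row, cols in zip(a, idx)]


a = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
idx = [[1, 3], [0, 1], [2, 3]]
print(select_columns(a, idx))  # [[1, 3], [4, 5], [10, 11]]
```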
Example:
In [2]: a = torch.arange(12).view(3, 4); a
Out[2]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
In [3]: idx = torch.tensor([[1, 3], [0, 1], [2, 3]]); idx  # all elements are < 4, the number of columns of a
Out[3]:
tensor([[1, 3],   # want elements 1 and 3 of row 0
        [0, 1],   # want elements 0 and 1 of row 1
        [2, 3]])  # want elements 2 and 3 of row 2
In [4]: row_idx = torch.arange(3).repeat_interleave(2)
...: col_idx = idx.view(-1)
...: (row_idx, col_idx)
Out[4]: (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 3, 0, 1, 2, 3]))
In [5]: a[row_idx, col_idx].view(3, -1)
Out[5]:
tensor([[ 1,  3],
        [ 4,  5],
        [10, 11]])
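For reference, PyTorch has a built-in op for exactly this pattern: `torch.gather(input, dim, index)` picks, along `dim`, the entries named by `index`, so with `dim=1` it computes `out[i][j] = a[i][idx[i][j]]` and returns a tensor shaped like `idx`. The sketch below applies it to the example above, assuming `idx` is an int64 tensor with the same number of dimensions and rows as `a`. It also shows a broadcast-indexing variant that keeps plain advanced indexing but drops the explicit `repeat_interleave` and `view` steps. Neither has been benchmarked here against the original approach.

```python
import torch

a = torch.arange(12).view(3, 4)
idx = torch.tensor([[1, 3], [0, 1], [2, 3]])  # int64 by default

# gather along dim=1: out_gather[i][j] = a[i][idx[i][j]]
# The result has the same shape as idx, so no reshape is needed.
out_gather = torch.gather(a, 1, idx)

# Alternative: advanced indexing with a (rows, 1) row index that
# broadcasts against idx of shape (rows, k), instead of building
# a flattened, repeated row index and reshaping afterwards.
rows = torch.arange(a.size(0)).unsqueeze(1)
out_index = a[rows, idx]

print(out_gather)
print(torch.equal(out_gather, out_index))  # True
```

Recent PyTorch releases also provide `torch.take_along_dim(a, idx, dim=1)`, which mirrors NumPy's `take_along_axis` and should give the same result for this case.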