Selecting from a 2d tensor with rows of column indexes

I have a 2D input tensor and a 2D index tensor with the same number of rows, where each value in the index tensor is the column index I want to select from the corresponding row of the input.

I want as output a 2D tensor with the same number of rows as the input tensor and the same number of columns as the index tensor.

I can do this by creating row and column index tensors, selecting from the input tensor, and then reshaping, but is there a more efficient way to do this?


```
In [1]: import torch

In [2]: a = torch.arange(12).view(3, 4); a
Out[2]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [3]: idx = torch.tensor([[1, 3], [0, 1], [2, 3]]); idx  # all values are < 4, the number of columns of a
Out[3]:
tensor([[1, 3],   # want columns 1 and 3 of row 0
        [0, 1],   # want columns 0 and 1 of row 1
        [2, 3]])  # want columns 2 and 3 of row 2

In [4]: row_idx = torch.arange(3).repeat_interleave(2)
   ...: col_idx = idx.view(-1)
   ...: (row_idx, col_idx)
Out[4]: (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 3, 0, 1, 2, 3]))

In [5]: a[row_idx, col_idx].view(3, -1)
Out[5]:
tensor([[ 1,  3],
        [ 4,  5],
        [10, 11]])
```
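The flatten-and-reshape approach above hard-codes 3 rows and 2 selected columns. Here is a minimal sketch of the same idea generalized to arbitrary shapes; `select_cols_flat` is a hypothetical helper name, and it assumes `idx` is an integer (int64) tensor whose values are valid column indexes of `a`.

```python
import torch

def select_cols_flat(a: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the question's flatten-then-reshape approach, any shape."""
    n_rows, k = idx.shape
    # Row number repeated once per selected column: [0]*k + [1]*k + ...
    row_idx = torch.arange(n_rows, device=a.device).repeat_interleave(k)
    # Column indexes flattened in row-major order so they line up with row_idx
    col_idx = idx.reshape(-1)
    # Pairwise (row, col) lookup gives a 1D result; reshape back to (n_rows, k)
    return a[row_idx, col_idx].view(n_rows, k)

a = torch.arange(12).view(3, 4)
idx = torch.tensor([[1, 3], [0, 1], [2, 3]])
print(select_cols_flat(a, idx))
```

It works, but it materializes two flat index tensors of length `n_rows * k`, which broadcasting avoids.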

Direct indexing with torch.arange should work:

```python
a[torch.arange(a.size(0)).unsqueeze(1), idx]
# tensor([[ 1,  3],
#         [ 4,  5],
#         [10, 11]])
```
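For this specific pattern (pick columns per row), PyTorch also has dedicated functions that nobody in the thread mentions: `torch.gather` with `dim=1` computes `out[i][j] = a[i][idx[i][j]]`, and `torch.take_along_dim` (available in recent PyTorch, 1.9+) mirrors NumPy's `take_along_axis`. A small sketch comparing them with the broadcasting-index answer:

```python
import torch

a = torch.arange(12).view(3, 4)
idx = torch.tensor([[1, 3], [0, 1], [2, 3]])

# Advanced indexing: the (3, 1) row index broadcasts against idx of shape (3, 2)
rows = torch.arange(a.size(0)).unsqueeze(1)
out_indexing = a[rows, idx]

# torch.gather along dim=1: out[i][j] = a[i][idx[i][j]]
out_gather = torch.gather(a, 1, idx)

# torch.take_along_dim selects along dim 1 the same way
out_take = torch.take_along_dim(a, idx, dim=1)

print(out_indexing)
print(torch.equal(out_indexing, out_gather), torch.equal(out_indexing, out_take))
```

`gather` requires `idx` to have the same number of dimensions as `a`, which holds here.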

Ah thanks!

To see if I understand correctly, what’s going on is:

  1. torch.arange(a.size(0)).unsqueeze(1) returns a 3x1 tensor, [[0],[1],[2]]
  2. That gets broadcast to the shape of idx (3x2) by ‘stretching’ (NumPy’s term) along the 2nd dimension. Call this r_idx
  3. The (i,j)th element of the output is the element of a in the r_idx[i,j]th row and idx[i,j]th column. That is, out[i,j] = a[r_idx[i,j], idx[i,j]]
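The three steps can be checked directly. This sketch makes the broadcast explicit with `expand_as` (the name `r_idx` is taken from the list above) and verifies element by element that `out[i, j] == a[r_idx[i, j], idx[i, j]]`:

```python
import torch

a = torch.arange(12).view(3, 4)
idx = torch.tensor([[1, 3], [0, 1], [2, 3]])

# Step 1: a (3, 1) column of row numbers, [[0], [1], [2]]
rows = torch.arange(a.size(0)).unsqueeze(1)

# Step 2: broadcasting stretches it along dim 1 to idx's shape (3, 2)
r_idx = rows.expand_as(idx)
# tensor([[0, 0],
#         [1, 1],
#         [2, 2]])

# Step 3: each output element pairs r_idx[i, j] with idx[i, j]
out = a[rows, idx]
for i in range(idx.size(0)):
    for j in range(idx.size(1)):
        assert out[i, j] == a[r_idx[i, j], idx[i, j]]

# Indexing with the explicitly expanded r_idx gives the same result
assert torch.equal(out, a[r_idx, idx])
```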

Is my understanding correct?

Where can I find this in the PyTorch docs? I have been combing over them, particularly this section on indexing and slicing, but haven’t found anything on how Tensor.__getitem__ works.

I think your explanation is correct.
I’m unsure whether this behavior is explicitly documented, but I think it mirrors NumPy’s advanced indexing logic. These docs explain NumPy indexing in more detail.
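Since the behavior follows NumPy, the same selection can be reproduced in NumPy itself. A minimal sketch, using `[:, None]` in place of `unsqueeze(1)` and `np.take_along_axis` as NumPy's dedicated function for this pattern:

```python
import numpy as np

a = np.arange(12).reshape(3, 4)
idx = np.array([[1, 3], [0, 1], [2, 3]])

# Advanced indexing: the (3, 1) and (3, 2) index arrays broadcast together,
# so out[i, j] = a[rows[i, 0], idx[i, j]]
rows = np.arange(a.shape[0])[:, None]
out = a[rows, idx]

# np.take_along_axis performs the same per-row selection along axis 1
out_take = np.take_along_axis(a, idx, axis=1)

print(out)
print(np.array_equal(out, out_take))
```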

Great, thank you!

I see that the Tensor Views docs have a note saying:

When accessing the contents of a tensor via indexing, PyTorch follows Numpy behaviors that basic indexing returns views, while advanced indexing returns a copy. Assignment via either basic or advanced indexing is in-place. See more examples in Numpy indexing documentation.
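The note's three claims can be observed on the tensors from this thread. In this sketch a slice (basic indexing) is written through and changes `a`, the broadcast-index result (advanced indexing) is an independent copy, and assigning through advanced indexing modifies `a` in place:

```python
import torch

a = torch.arange(12).view(3, 4)
idx = torch.tensor([[1, 3], [0, 1], [2, 3]])
rows = torch.arange(a.size(0)).unsqueeze(1)

# Basic indexing (slicing) returns a view: writing to it changes a
view = a[:, 1:3]
view[0, 0] = 100
print(a[0, 1])  # tensor(100)

# Advanced indexing returns a copy: writing to it leaves a untouched
copy = a[rows, idx]
copy[0, 0] = -1
print(a[0, 1])  # still tensor(100)

# Advanced indexing on the left-hand side of an assignment is in-place
a[rows, idx] = 0
print(a)
```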

I will open an issue on github to request that this is made explicit in the indexing/slicing docs.