How to select particular elements of a 3D tensor based on indices along dim 1 in PyTorch?

I have a 3D tensor A of shape (M, N, K) where M is the batch size, N is sequence length, and K is embedding dimension. I also have a list of indices along dimension 1 (i.e., indices for sequence length dimension). I am unable to figure out how to get elements of A based on this indexing.

For example, with M = 2, N = 4, K = 2:
A = [[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]]
Index = [0,2]

output = [[1,2], [13,14]]

Note that each element of the batch has a different sequence index. Basically, I want to extract the embedding of the </s> token, and for each sample in the batch, </s> can be at a different index.
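If the </s> positions are not already available, one possible way to build such an index tensor is sketched below. This assumes the token ids are available as a tensor of shape (M, N) (here called input_ids, a hypothetical name), that eos_id is the id of </s> (the value 2 is made up), and that every sample contains exactly one </s>.

```python
import torch

eos_id = 2  # hypothetical token id of </s>
# Hypothetical token ids of shape (M, N); each row holds exactly one </s>
input_ids = torch.tensor([[2, 5, 7, 0],
                          [4, 6, 2, 0]])

is_eos = input_ids == eos_id
# Sanity check for the assumption of exactly one </s> per sample
assert bool((is_eos.sum(dim=1) == 1).all())

# nonzero() returns (row, col) pairs in row-major order, so with one
# </s> per row, column 1 holds the </s> position of each sample
index = is_eos.nonzero()[:, 1]
print(index)  # tensor([0, 2])
```

If a row could contain several </s> tokens (or none), this shortcut would produce the wrong number of indices, hence the assertion.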


This should work:

import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0,2])

out = A[torch.arange(A.size(0)), index]
print(out)
> tensor([[ 1,  2],
          [13, 14]])
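For comparison, the same selection can also be written with torch.gather. This is a sketch, not part of the original answer: gather needs an index with the same number of dimensions as A, so index is reshaped to (M, 1, 1) and expanded along dim 2 before gathering along dim 1.

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0, 2])

# Shape (M, 1, K): the same sequence position repeated for every embedding dim
gather_idx = index.view(-1, 1, 1).expand(-1, 1, A.size(2))
# out[i, 0, k] = A[i, index[i], k]; squeeze the singleton dim to get (M, K)
out_gather = A.gather(1, gather_idx).squeeze(1)

out_adv = A[torch.arange(A.size(0)), index]
print(torch.equal(out_gather, out_adv))  # True
```

Both give the same result here; the advanced-indexing version is shorter, while gather makes the "pick along dim 1" intent explicit.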

@ptrblck, it worked. Thanks a lot.
It would be very helpful if you could explain how this works. I don't understand what torch.arange is doing here.

torch.arange(A.size(0)) will return tensor([0, 1]) in this case, so the indexing will work as:

out[0] = A[0, index[0]]
out[1] = A[1, index[1]]

and will thus yield the desired results. On the other hand, A[:, index] (the indexing operation you might be more familiar with) will not work, since it applies index to every entry in dim0, so it'll return:

A[:, index]
> tensor([[[ 1,  2],
           [ 5,  6]],

          [[ 9, 10],
           [13, 14]]])

As you can see, index (containing [0, 2]) was now applied to all values in dim0 as:

A[0, index]
A[1, index]

thus returning a tensor of shape (M, len(index), K) instead of the (M, K) result you want.

EDIT: this numpy doc might also be helpful in case my explanation is confusing.
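To check the explanation above, here is a small sketch that builds the result with an explicit Python loop (out[i] = A[i, index[i]]) and compares it with the arange-based indexing. It also prints the shape of A[:, index] to show the extra dimension. The loop is only for illustration and will be much slower than advanced indexing on large tensors.

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0, 2])

# Explicit loop: row i of the batch paired with index[i]
out_loop = torch.stack([A[i, index[i].item()] for i in range(A.size(0))])
# Advanced indexing: arange and index are paired element by element
out_adv = A[torch.arange(A.size(0)), index]
assert torch.equal(out_loop, out_adv)
print(out_adv.shape)     # torch.Size([2, 2])

# A[:, index] applies the whole index to every batch entry instead
print(A[:, index].shape)  # torch.Size([2, 2, 2])
```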


Understood. Thanks a lot.

@ptrblck, how can I do this if I have indices like

index = torch.tensor([[0,2], [1,3], [0,1]])

Basically, for each batch element, I have a different list of indices.

Could you post the desired output using this new index and the previously defined A tensor?

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]], [[21,22],[23,24],[25,26],[27,28]]])
index = torch.tensor([[0,2], [1,3], [0,1]])
output = [[[1,2], [5,6]], [[11,12], [15,16]], [[21,22], [23,24]]]

This should work:

import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]], [[21,22],[23,24],[25,26],[27,28]]])
index = torch.tensor([[0,2], [1,3], [0,1]])
A[torch.arange(A.size(0)).unsqueeze(1), index]
> tensor([[[ 1,  2],
           [ 5,  6]],

          [[11, 12],
           [15, 16]],

          [[21, 22],
           [23, 24]]])
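Here torch.arange(A.size(0)).unsqueeze(1) has shape (M, 1), which broadcasts against index of shape (M, J), so out[i, j] = A[i, index[i, j]]. The sketch below makes the broadcast visible and checks the result against an explicit loop and a torch.gather alternative; the loop and gather variants are additions for illustration, not part of the answer above.

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]],
                  [[9,10],[11,12],[13,14],[15,16]],
                  [[21,22],[23,24],[25,26],[27,28]]])
index = torch.tensor([[0, 2], [1, 3], [0, 1]])

rows = torch.arange(A.size(0)).unsqueeze(1)  # shape (3, 1)
# rows (3, 1) and index (3, 2) broadcast to a common shape (3, 2)
r, c = torch.broadcast_tensors(rows, index)
print(r.tolist())  # [[0, 0], [1, 1], [2, 2]]

out = A[rows, index]  # shape (3, 2, 2)

# Explicit loop check: out[i, j] = A[i, index[i, j]]
out_loop = torch.stack([
    torch.stack([A[i, index[i, j].item()] for j in range(index.size(1))])
    for i in range(A.size(0))
])
assert torch.equal(out, out_loop)

# gather alternative: expand index to (M, J, K) and gather along dim 1
out_gather = A.gather(1, index.unsqueeze(-1).expand(-1, -1, A.size(2)))
assert torch.equal(out, out_gather)
```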