# How to select particular elements of a 3D tensor based on indices along dim 1 in PyTorch?

I have a 3D tensor A of shape (M, N, K) where M is the batch size, N is sequence length, and K is embedding dimension. I also have a list of indices along dimension 1 (i.e., indices for sequence length dimension). I am unable to figure out how to get elements of A based on this indexing.

For example, with M = 2, N = 4, K = 2:
A = [[[1,2],[3,4],[5,6],[7,8]] , [[9,10],[11,12],[13,14],[15,16]]]
Index = [0,2]

output = [[1,2] ,[13,14]]

Note that each element of the batch has a different sequence index. Basically, I want to extract the embeddings of the `</s>` token, and for each sample in the batch this `</s>` can appear at a different index.
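If the embeddings come from token ids, one way to build `Index` is to find the position of `</s>` in each row. Below is a minimal sketch. It assumes a hypothetical `input_ids` tensor of shape (M, N) and a hypothetical `eos_id` for `</s>`, and it requires exactly one `</s>` per sequence:

```python
import torch

eos_id = 2  # hypothetical id of the </s> token
# hypothetical token ids, shape (M, N) = (2, 4); exactly one </s> per row
input_ids = torch.tensor([[2, 5, 7, 9],
                          [4, 6, 2, 0]])

# nonzero(as_tuple=True) returns the (row, column) coordinates of every match
# in row-major order, so with one </s> per row `cols` holds each sample's position
rows, cols = (input_ids == eos_id).nonzero(as_tuple=True)
print(cols)  # tensor([0, 2])
```

If a row had no `</s>` or several of them, `cols` would no longer line up one-to-one with the batch, so that assumption matters.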


This should work:

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]] , [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0,2])

out = A[torch.arange(A.size(0)), index]
print(out)
# tensor([[ 1,  2],
#         [13, 14]])
```
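The original answer does not use `torch.gather`, but it can produce the same result. It needs an index tensor with the same number of dimensions as `A`, so the per-sample positions are reshaped to (M, 1) and expanded to (M, 1, K). Here is a sketch, for comparison only:

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0, 2])
M, N, K = A.shape

# gather(dim=1) computes out[i][j][k] = A[i][gather_idx[i][j][k]][k],
# so every k in row i must point at the same sequence position index[i]
gather_idx = index.view(M, 1, 1).expand(M, 1, K)
out = A.gather(1, gather_idx).squeeze(1)  # (M, 1, K) -> (M, K)
print(out.tolist())  # [[1, 2], [13, 14]]
```

The advanced-indexing version is shorter. `gather` is mainly useful when you already have a full-rank index tensor.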

@ptrblck, it worked. Thanks a lot.
It would be very helpful if you could explain how this works. I don't understand what `torch.arange` is doing here.

`torch.arange(A.size(0))` will return `tensor([0, 1])` in this case, so the indexing will work as:

```python
out[0] = A[0, index[0]]
out[1] = A[1, index[1]]
```
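To check the pairing described above, you can compare the indexed result with an explicit Python loop. The loop takes row `index[i]` from batch element `i`. This is a small sanity check on the example tensors, not something the indexing needs:

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0, 2])

batch_idx = torch.arange(A.size(0))  # tensor([0, 1])
fancy = A[batch_idx, index]          # pairs batch_idx[i] with index[i]

# explicit loop: one (batch, position) pair per batch element
looped = torch.stack([A[i, int(index[i])] for i in range(A.size(0))])
print(torch.equal(fancy, looped))  # True
```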

and will thus yield the desired results. On the other hand, `A[:, index]` will not work (you might be more familiar with this indexing operation), because it applies `index` to "all" values in dim0, so it'll return:

```python
A[:, index]
# tensor([[[ 1,  2],
#          [ 5,  6]],
#
#         [[ 9, 10],
#          [13, 14]]])
```

As you can see, `index` (containing `[0, 2]`) was now applied to all values in `dim0` as:

```python
A[0, index]
A[1, index]
```

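thus returning more values: one `(len(index), K)` block per batch element, i.e. a tensor of shape `[2, 2, 2]` instead of `[2, 2]`.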
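The two results are related. `A[:, index][i, j]` equals `A[i, index[j]]`, so the desired paired result sits on the "diagonal" (`i == j`) of the larger tensor. The sketch below compares the shapes on the example data. It also shows why `A[:, index]` does extra work: it builds every batch/index combination and discards the off-diagonal ones.

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0, 2])
batch_idx = torch.arange(A.size(0))

broadcast = A[:, index]       # broadcast[i, j] = A[i, index[j]] -> (M, len(index), K)
paired = A[batch_idx, index]  # paired[i] = A[i, index[i]]       -> (M, K)
print(broadcast.shape)  # torch.Size([2, 2, 2])
print(paired.shape)     # torch.Size([2, 2])

# taking the i == j entries of the broadcast result recovers the paired one
print(torch.equal(broadcast[batch_idx, batch_idx], paired))  # True
```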

EDIT: this numpy doc might also be helpful in case my explanation is confusing.
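PyTorch follows NumPy's integer-array indexing rules here, so the same pattern works unchanged on a NumPy array. This sketch uses the example data:

```python
import numpy as np

A = np.array([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = np.array([0, 2])

# the two index arrays are broadcast together and paired element-wise
out = A[np.arange(A.shape[0]), index]
print(out.tolist())  # [[1, 2], [13, 14]]
```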


Understood. Thanks a lot.

@ptrblck, how can I do this if I have indices like the following?

`index = torch.tensor([[0,2],[1,3] , [0,1]])`

Basically, for each batch element, I have a different list of indices.

Could you post the desired output using this new `index` and the previously defined `A` tensor?

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]] , [[9,10],[11,12],[13,14],[15,16]] , [[21,22],[23,24] ,[25,26],[27,28]]])
index = torch.tensor([[0,2],[1,3] , [0,1]])
output = [[[1,2] , [5,6]] ,[[11,12] , [15,16]] , [[21,22] , [23,24]]]

This should work:

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]] , [[9,10],[11,12],[13,14],[15,16]] , [[21,22],[23,24] ,[25,26],[27,28]]])
index = torch.tensor([[0,2], [1,3], [0,1]])
print(A[torch.arange(A.size(0)).unsqueeze(1), index])
# tensor([[[ 1,  2],
#          [ 5,  6]],
#
#         [[11, 12],
#          [15, 16]],
#
#         [[21, 22],
#          [23, 24]]])
```
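The `unsqueeze(1)` turns the batch indices into shape (M, 1). Broadcasting then matches them against `index` of shape (M, J), so element `[i, j]` of the result is `A[i, index[i, j]]` and the output has shape (M, J, K). The sketch below checks this against an explicit loop. It also shows an equivalent `torch.gather` call, which is an alternative and not part of the original answer:

```python
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]],
                  [[21,22],[23,24],[25,26],[27,28]]])
index = torch.tensor([[0, 2], [1, 3], [0, 1]])
M, N, K = A.shape

batch_idx = torch.arange(M).unsqueeze(1)  # shape (M, 1), broadcast against (M, J)
out = A[batch_idx, index]                 # out[i, j] = A[i, index[i, j]] -> (M, J, K)

# explicit loop: each batch element selects its own list of positions
looped = torch.stack([A[i, index[i]] for i in range(M)])
print(torch.equal(out, looped))  # True

# gather alternative: expand index from (M, J) to (M, J, K)
gathered = A.gather(1, index.unsqueeze(-1).expand(-1, -1, K))
print(torch.equal(out, gathered))  # True
```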