I have a 3D tensor A of shape (M, N, K) where M is the batch size, N is sequence length, and K is embedding dimension. I also have a list of indices along dimension 1 (i.e., indices for sequence length dimension). I am unable to figure out how to get elements of A based on this indexing.

For example, with M = 2, N = 4, K = 2:

A = [[[1,2],[3,4],[5,6],[7,8]] , [[9,10],[11,12],[13,14],[15,16]]]

Index = [0,2]

output = [[1,2] ,[13,14]]

Note that each element of the batch has a different index along the sequence dimension. Basically, I want to extract the embedding of the `</s>` token, and for each sample in the batch this `</s>` can lie at a different index.


This should work:

```
import torch

A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]])
index = torch.tensor([0,2])
# pair each batch position with its own sequence index
out = A[torch.arange(A.size(0)), index]
print(out)
# tensor([[ 1,  2],
#         [13, 14]])
```


@ptrblck, It worked. Thanks a lot.

It would be very helpful if you could explain how this works. I don't understand what `torch.arange` is doing here.

`torch.arange(A.size(0))` will return `tensor([0, 1])` in this case, so the indexing will work as:

```
out[0] = A[0, index[0]]
out[1] = A[1, index[1]]
```

and will thus yield the desired results. On the other hand, using `A[:, index]` will not work (you might be more familiar with this indexing operation), as it would use `index` on "all" values in dim0, so it'll return:

```
A[:, index]
# tensor([[[ 1,  2],
#          [ 5,  6]],
#
#         [[ 9, 10],
#          [13, 14]]])
```

As you can see, `index` (containing `[0, 2]`) was now applied to all values in `dim0` as:

```
A[0, index]
A[1, index]
```

thus returning more values.
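To check the difference yourself, here is a minimal sketch. The names `batch_idx`, `paired`, `looped`, and `sliced` are only illustrative and do not come from the posts above. It compares the paired indexing with an explicit loop and with `A[:, index]`:

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]]])
index = torch.tensor([0, 2])

# Paired indexing: batch position b is matched with index[b]
batch_idx = torch.arange(A.size(0))   # tensor([0, 1])
paired = A[batch_idx, index]          # picks A[0, 0] and A[1, 2]

# The same selection written as an explicit loop
looped = torch.stack([A[b, index[b]] for b in range(A.size(0))])

# Slicing dim0 applies every entry of index to every batch element
sliced = A[:, index]

print(paired.shape)                 # torch.Size([2, 2])
print(sliced.shape)                 # torch.Size([2, 2, 2])
print(torch.equal(paired, looped))  # True
# The wanted rows sit on the "diagonal" of the sliced result
print(torch.equal(paired, sliced[batch_idx, batch_idx]))  # True
```

The loop is only there to make the pairing explicit; the indexed version avoids the Python loop.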

EDIT: this NumPy doc might also be helpful in case my explanation is confusing.


Understood. Thanks a lot.

@ptrblck, how can I do this if I have indices like the following?

index = torch.tensor([[0,2],[1,3],[0,1]])

Basically, for each batch element, I have a different list of indices.

Could you post the desired output using this new `index` and the previously defined `A` tensor?

This should work:

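```
import torch

# A now has three batch elements, one per row of index
A = torch.tensor([[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]], [[21,22],[23,24],[25,26],[27,28]]])
index = torch.tensor([[0,2], [1,3], [0,1]])
# the (3, 1) batch index broadcasts against the (3, 2) index
out = A[torch.arange(A.size(0)).unsqueeze(1), index]
print(out)
# tensor([[[ 1,  2],
#          [ 5,  6]],
#
#         [[11, 12],
#          [15, 16]],
#
#         [[21, 22],
#          [23, 24]]])
```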
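In case it helps to see why the `unsqueeze(1)` is needed, here is a small sketch. The names `batch_idx`, `gather_idx`, and `gathered` are my own for illustration. The batch index of shape `(3, 1)` broadcasts against `index` of shape `(3, 2)`, so every entry in row `b` of `index` is paired with batch element `b`. As a cross-check, the same selection can also be written with `torch.gather`, assuming the index is expanded to cover the embedding dimension:

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]],
                  [[21, 22], [23, 24], [25, 26], [27, 28]]])
index = torch.tensor([[0, 2], [1, 3], [0, 1]])     # shape (3, 2)

# (3, 1) batch index broadcasts with the (3, 2) index to shape (3, 2)
batch_idx = torch.arange(A.size(0)).unsqueeze(1)
out = A[batch_idx, index]                          # out[b, j] = A[b, index[b, j]]

# Alternative: gather along dim 1, with the index repeated over dim 2
gather_idx = index.unsqueeze(-1).expand(-1, -1, A.size(2))  # shape (3, 2, 2)
gathered = torch.gather(A, 1, gather_idx)

print(batch_idx.shape)              # torch.Size([3, 1])
print(out.shape)                    # torch.Size([3, 2, 2])
print(torch.equal(out, gathered))   # True
```

Both forms return the same tensor here; the indexing version is shorter, while `gather` makes the dimension being indexed explicit.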
