I have a 3D tensor A of shape (M, N, K), where M is the batch size, N is the sequence length, and K is the embedding dimension. I also have a list of indices along dimension 1 (i.e., the sequence-length dimension). I can't figure out how to select elements of A using these indices.
For example, with M = 2, N = 4, K = 2:
A = [[[1,2],[3,4],[5,6],[7,8]], [[9,10],[11,12],[13,14],[15,16]]]
index = [0,2]
output = [[1,2], [13,14]]
Note that each element of the batch has its own sequence index. Basically, I want to extract the embedding of the </s> token, and for each sample in the batch, </s> can be at a different position.
This should work:
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]], [[9, 10], [11, 12], [13, 14], [15, 16]]])
index = torch.tensor([0, 2])
# Pair each batch position i with index[i]
out = A[torch.arange(A.size(0)), index]
print(out)
> tensor([[ 1,  2],
          [13, 14]])
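To tie this back to the original goal (one </s> embedding per sample), here is a minimal sketch that first derives index from token IDs and then applies the same indexing. The names token_ids and EOS_ID, and the value 2 for the </s> ID, are hypothetical. It assumes every row contains the </s> ID at least once; for a row without it, argmax would silently return 0.

```python
import torch

# Hypothetical inputs: token IDs of shape (M, N) and embeddings A of shape (M, N, K)
EOS_ID = 2  # hypothetical ID of the </s> token
token_ids = torch.tensor([[2, 5, 7, 9],
                          [4, 6, 2, 0]])
A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]]])

# argmax over a 0/1 mask gives the position of the first EOS in each row
# (assumes each row contains EOS_ID at least once)
index = (token_ids == EOS_ID).float().argmax(dim=1)  # tensor([0, 2])

# Pair batch position i with index[i]: out[i] = A[i, index[i]]
out = A[torch.arange(A.size(0)), index]
print(out.tolist())  # [[1, 2], [13, 14]]
```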
@ptrblck, it worked. Thanks a lot.
It would be very helpful if you could explain how this works. I don't understand what torch.arange is doing here.
torch.arange(A.size(0)) returns tensor([0, 1]) in this case. When you index with two tensors like this, their entries are paired up position by position, so the indexing works as:
out[0] = A[0, index[0]]
out[1] = A[1, index[1]]
and thus yields the desired result. On the other hand, using A[:, index] (the indexing operation you might be more familiar with) will not work, as it would apply index to “all” values in dim0, so it'll return:
A[:, index]
> tensor([[[ 1,  2],
           [ 5,  6]],
          [[ 9, 10],
           [13, 14]]])
As you can see, index (containing [0, 2]) was now applied to every value in dim0, i.e.:
A[0, index]
A[1, index]
thus returning more values: a tensor of shape (2, 2, 2) instead of (2, 2).
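The two behaviours can be checked against explicit Python loops. This is a small sketch reusing the A and index from the example; the loop versions are only there to spell out what each indexing expression computes and are not meant for real use.

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]]])
index = torch.tensor([0, 2])
batch = torch.arange(A.size(0))  # tensor([0, 1])

# Two index tensors are paired elementwise: paired[i] = A[batch[i], index[i]]
paired = A[batch, index]
# A slice on dim0 plus one index tensor: index is applied to every batch element
crossed = A[:, index]

# The same computations written as explicit loops
paired_loop = torch.stack([A[i, int(index[i])] for i in range(A.size(0))])
crossed_loop = torch.stack([A[i, index] for i in range(A.size(0))])

print(paired.shape)   # torch.Size([2, 2])
print(crossed.shape)  # torch.Size([2, 2, 2])
```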
EDIT: this NumPy doc might also be helpful in case my explanation is confusing.
Understood. Thanks a lot.
@ptrblck, how can I do it if I have indices like the following?
index = torch.tensor([[0, 2], [1, 3], [0, 1]])
Basically, for each batch element, I have a different list of indices.
Could you post the desired output using this new index and the previously defined A tensor?
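This should work (note that A now has a third batch element so that it matches the three rows of index):
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]],
                  [[21, 22], [23, 24], [25, 26], [27, 28]]])
index = torch.tensor([[0, 2], [1, 3], [0, 1]])
# arange(...).unsqueeze(1) has shape (3, 1) and broadcasts against index, shape (3, 2)
A[torch.arange(A.size(0)).unsqueeze(1), index]
> tensor([[[ 1,  2],
           [ 5,  6]],
          [[11, 12],
           [15, 16]],
          [[21, 22],
           [23, 24]]])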
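The unsqueeze(1) is what makes this work: the (3, 1) batch index and the (3, 2) sequence index are broadcast to a common (3, 2) shape, so out[i, j] = A[i, index[i, j]]. The sketch below makes that broadcast visible with torch.broadcast_tensors and, as an alternative not used in this thread, checks the result against torch.gather, which needs the index repeated across the K embedding values.

```python
import torch

A = torch.tensor([[[1, 2], [3, 4], [5, 6], [7, 8]],
                  [[9, 10], [11, 12], [13, 14], [15, 16]],
                  [[21, 22], [23, 24], [25, 26], [27, 28]]])
index = torch.tensor([[0, 2], [1, 3], [0, 1]])  # shape (3, 2)

rows = torch.arange(A.size(0)).unsqueeze(1)     # shape (3, 1): [[0], [1], [2]]

# Advanced indexing broadcasts both index tensors to a common shape (3, 2)
rows_b, index_b = torch.broadcast_tensors(rows, index)
print(rows_b.tolist())  # [[0, 0], [1, 1], [2, 2]]

# out[i, j] = A[rows_b[i, j], index_b[i, j]] = A[i, index[i, j]]
out = A[rows, index]                            # shape (3, 2, 2)

# Alternative: torch.gather along dim 1, with index expanded to shape (3, 2, K)
gathered = A.gather(1, index.unsqueeze(-1).expand(-1, -1, A.size(2)))
print(torch.equal(out, gathered))  # True
```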