How to select from a 3D tensor using a 1D tensor of indices

Suppose I have a tensor A of size batch_size x num_class x Dim, and a batch of labels L of size batch_size, where each element specifies which index along the second dimension to choose. I want to slice A using L, resulting in a tensor of size batch_size x Dim. Is there a good way to implement this? Can anyone help me out?
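To pin down the intended semantics, here is a minimal reference sketch using an explicit Python loop. It assumes A is a float tensor and L holds int64 class indices in the range [0, num_class); the sizes are arbitrary illustration values, not anything from the original question.

```python
import torch

batch_size, num_class, dim = 4, 3, 5  # arbitrary illustration sizes

A = torch.randn(batch_size, num_class, dim)
L = torch.randint(0, num_class, (batch_size,))  # one class index per batch row (int64)

# Reference (slow) version: for each row i, pick the vector A[i, L[i], :]
out = torch.stack([A[i, L[i]] for i in range(batch_size)])

print(out.shape)  # torch.Size([4, 5])
```

The loop is only meant as a correctness baseline; the vectorized answers in this thread should produce the same tensor.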


This is similar to "Indexing Multi-dimensional Tensors based on 1D tensor of indices".

I guess what you need is:

A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)

I haven’t found anything better for now…
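A runnable sketch of the gather approach, with arbitrary sizes and L assumed to be int64. Note that gather keeps the indexed dimension, so its raw result is batch_size x 1 x Dim; the squeeze(1) drops that singleton dimension to get the requested batch_size x Dim.

```python
import torch

batch_size, num_class, dim = 4, 3, 5  # arbitrary illustration sizes
A = torch.randn(batch_size, num_class, dim)
L = torch.randint(0, num_class, (batch_size,))  # int64 class indices

# gather needs an index tensor with the same number of dims as A.
# Reshape L to (bz, 1, 1), then expand it along the last dim to (bz, 1, Dim)
# so every element along Dim reads from class L[i].
index = L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))
out = A.gather(1, index)  # out[i, 0, k] = A[i, L[i], k], shape (bz, 1, Dim)
out = out.squeeze(1)      # drop the singleton class dim -> (bz, Dim)
```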


Thank you very much for your solution.

Update:

Another way to do this now is:

A[torch.arange(A.size(0), device=A.device), L]

It is also much faster than the previous answer (between 10x and 100x).
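A small runnable check that the indexing form returns the same tensor as the gather form, using arbitrary sizes. It does not try to reproduce the 10x to 100x speedup, which likely depends on hardware and tensor sizes. Creating the arange on A's device is an assumption made here to keep both index tensors on the same device when A lives on a GPU.

```python
import torch

batch_size, num_class, dim = 4, 3, 5  # arbitrary illustration sizes
A = torch.randn(batch_size, num_class, dim)
L = torch.randint(0, num_class, (batch_size,))

# Advanced indexing: the two index tensors are paired element-wise,
# so row i of the result is A[i, L[i], :].
rows = torch.arange(A.size(0), device=A.device)
out_index = A[rows, L]  # shape (bz, Dim)

# Same result as the gather-based answer
out_gather = A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)
assert torch.equal(out_index, out_gather)
```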
