Suppose I have a tensor A of size batch_size x num_class x Dim and a batch of labels L of size batch_size, where each element specifies which index along the second dimension to choose. I want to slice A using L, resulting in a tensor of size batch_size x Dim. Is there a good way to implement this? Can anyone help me out?
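To pin down the intended semantics, here is a minimal loop-based sketch. It assumes L is an int64 tensor and uses illustrative sizes (batch_size=4, num_class=3, dim=5); the helper name `select_loop` is hypothetical. Row i of the result is simply `A[i, L[i], :]`.

```python
import torch

def select_loop(A: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper: row i of the result is A[i, L[i], :]
    return torch.stack([A[i, L[i]] for i in range(A.size(0))])

batch_size, num_class, dim = 4, 3, 5  # illustrative sizes
A = torch.randn(batch_size, num_class, dim)
L = torch.tensor([0, 2, 1, 2])  # one class index per batch element (int64)

out = select_loop(A, L)
print(out.shape)  # torch.Size([4, 5])
```

A Python loop like this is slow for large batches, but it is a useful ground truth for checking vectorized versions.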
This is similar to the thread "Indexing Multi-dimensional Tensors based on 1D tensor of indices".
I guess what you need is:
A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)
I haven’t found anything better for now…
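A runnable sketch of the `gather` approach, assuming L is an int64 tensor and using illustrative sizes. `gather` returns a tensor shaped like its index, here (batch_size, 1, Dim), so a final `squeeze(1)` is needed to get batch_size x Dim.

```python
import torch

batch_size, num_class, dim = 4, 3, 5  # illustrative sizes
A = torch.randn(batch_size, num_class, dim)
L = torch.tensor([0, 2, 1, 2])  # gather requires an int64 index

# Reshape L to (batch_size, 1, 1), then expand (no copy) to (batch_size, 1, dim)
index = L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))

# Along dim 1: out[i, 0, d] = A[i, index[i, 0, d], d] = A[i, L[i], d]
out = A.gather(1, index).squeeze(1)  # drop the singleton class dimension
print(out.shape)  # torch.Size([4, 5])
```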
Thank you very much for your solution.
Update:
Another way to do this now is:
A[torch.arange(A.size(0)), L]
This is also much faster than the previous answer (between 10x and 100x).
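A minimal sketch of the advanced-indexing version with illustrative sizes. The row indices from `torch.arange` and the labels L are broadcast together, so element i selects `A[i, L[i], :]`. Creating the `arange` on `A.device` is a small assumption added here so the same code also works when A lives on a GPU (L must be on the same device too). The last lines cross-check the result against the `gather` formulation.

```python
import torch

batch_size, num_class, dim = 4, 3, 5  # illustrative sizes
A = torch.randn(batch_size, num_class, dim)
L = torch.tensor([0, 2, 1, 2])

# Pair each row index i with L[i]; the trailing Dim axis is kept whole
rows = torch.arange(A.size(0), device=A.device)
out_index = A[rows, L]  # shape (batch_size, dim)

# Cross-check against the gather-based formulation
out_gather = A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)
print(torch.equal(out_index, out_gather))  # True
```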