Suppose I have a tensor A of size `batch_size x num_class x Dim` and a batch of labels L of size `batch_size`, where each element specifies which index along the second dimension to choose. I want to slice A using L, resulting in a tensor of size `batch_size x Dim`. Is there a good way to implement this? Can anyone help me out?
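To pin down the desired result, here is a minimal (and deliberately slow) loop sketch: for each batch row `i`, take the `Dim`-vector `A[i, L[i], :]`. The function name `select_by_label_loop` and the toy shapes are illustrative only.

```python
import torch

def select_by_label_loop(A: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Reference version: out[i] = A[i, L[i], :] for every batch row i."""
    bz, _, dim = A.shape
    out = A.new_empty(bz, dim)
    for i in range(bz):
        out[i] = A[i, L[i].item()]  # pick one class row per batch element
    return out

# Toy input: batch_size=2, num_class=3, Dim=4
A = torch.arange(2 * 3 * 4, dtype=torch.float32).view(2, 3, 4)
L = torch.tensor([2, 0])
out = select_by_label_loop(A, L)
print(out)  # rows: A[0, 2] = [8, 9, 10, 11] and A[1, 0] = [12, 13, 14, 15]
```

The vectorized answers below should produce the same `batch_size x Dim` tensor without the Python loop.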

This is similar to the thread "Indexing Multi-dimensional Tensors based on 1D tensor of indices".

I guess what you need is:

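```
# gather along dim 1 returns shape (batch_size, 1, Dim); squeeze(1) drops the singleton dim
A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)
```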
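A sketch of why the `gather` call works, assuming `L` is an int64 tensor (which `gather` requires for its index). The index is `L` broadcast to shape `(batch_size, 1, Dim)`, so every entry in row `i` equals `L[i]`, and `gather` along dim 1 then reads `A[i, L[i], d]` for each `d`. The helper name `select_by_label_gather` is hypothetical.

```python
import torch

def select_by_label_gather(A: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    bz, _, dim = A.shape
    # index[i, 0, d] == L[i] for every d; shape (bz, 1, Dim)
    index = L.view(-1, 1, 1).expand(bz, 1, dim)
    # out[i, 0, d] = A[i, index[i, 0, d], d] = A[i, L[i], d]
    picked = A.gather(1, index)  # shape (bz, 1, Dim)
    return picked.squeeze(1)     # shape (bz, Dim)

A = torch.randn(5, 3, 4)
L = torch.randint(0, 3, (5,))  # int64 by default
out = select_by_label_gather(A, L)
print(out.shape)  # torch.Size([5, 4])
```

Without the final `squeeze(1)`, the result keeps a singleton class dimension, i.e. `batch_size x 1 x Dim`.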

I haven’t found anything better for now…

Thank you very much for your solution.

**Update**:

Another way to do this now is:

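```
# pair row index i with label L[i] to pick A[i, L[i], :] -> shape (batch_size, Dim)
A[torch.arange(A.size(0), device=A.device), L]
```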

This is also much faster than the previous answer (between 10x and 100x).
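A quick consistency check, as a sketch: advanced indexing with two index tensors pairs them element-wise, so `A[rows, L]` selects `A[rows[i], L[i], :]` for each `i`. The snippet below (helper name `select_by_label_index` is hypothetical) confirms it returns exactly the same values as the `gather` version. The speed-up quoted above is the poster's own observation; actual timings will depend on tensor shapes and device.

```python
import torch

def select_by_label_index(A: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    # rows = [0, 1, ..., bz-1], created on A's device so it also works on GPU
    rows = torch.arange(A.size(0), device=A.device)
    return A[rows, L]  # shape (bz, Dim)

A = torch.randn(8, 10, 16)
L = torch.randint(0, 10, (8,))
via_index = select_by_label_index(A, L)
via_gather = A.gather(1, L.view(-1, 1, 1).expand(A.size(0), 1, A.size(2))).squeeze(1)
print(torch.equal(via_index, via_gather))  # True
```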