Is there any better way to do this kind of indexing?

For example, I have a tensor a of shape [k, m, n] and a tensor b of shape [k]. What I want is a[torch.arange(k), b]. Does PyTorch provide an API that handles this directly? Creating a new tensor with torch.arange(k) doesn’t look that elegant.


I’m not sure what you expect to get here. What does b contain?
Can you share a small code sample that shows what you expect as output?

Sorry, I should have included more information. Every element of b is in [0, m). Think of a sequence-ranking task: k is the batch size, and each batch instance has m candidate sequences of length n. b holds the index of the chosen candidate for each instance, so it has k entries, one per batch instance. I’d expect there to be a more straightforward way to handle this, since it seems like a common need.
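To make the expected output concrete, here is a minimal sketch of the selection described above. The sizes and the values in b are made up for illustration; the explicit loop is only there to spell out the intended result, which is then compared with the arange-based indexing from the original question.

```python
import torch

# Illustrative sizes (assumptions): k batch instances, m candidates, length n.
k, m, n = 3, 4, 5
a = torch.arange(k * m * n).reshape(k, m, n)

# One candidate index per batch instance, each in [0, m).
b = torch.tensor([2, 0, 3])

# Spelled out: for instance i, take candidate b[i], keeping all n elements.
expected = torch.stack([a[i, b[i]] for i in range(k)])

# The indexing from the question gives the same [k, n] tensor.
indexed = a[torch.arange(k), b]
```

Both expected and indexed have shape [k, n], with row i equal to a[i, b[i], :].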

Yes, you can do: a.gather(1, b.view(k, 1, 1).expand(k, 1, n)). Note that the result has shape [k, 1, n], so call .squeeze(1) on it to get the same [k, n] tensor as a[torch.arange(k), b].
Basically, use the gather function; since you want all the elements along the last dimension of size n, you just expand the index tensor over that dimension.
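As a quick check of the gather suggestion above, the sketch below compares it with the arange-based indexing from the question. The sizes and the random inputs are assumptions for illustration only. One detail worth noting: gather keeps the indexed dimension, so the result needs a squeeze(1) to end up with shape [k, n].

```python
import torch

# Illustrative sizes and random inputs (assumptions, not from the thread).
k, m, n = 3, 4, 5
torch.manual_seed(0)
a = torch.randn(k, m, n)
b = torch.randint(0, m, (k,))  # int64 indices in [0, m), as gather requires

# Expand the indices over the last dimension so every one of the n
# elements of the chosen candidate is gathered.
idx = b.view(k, 1, 1).expand(k, 1, n)

gathered = a.gather(1, idx)    # shape [k, 1, n]
picked = gathered.squeeze(1)   # shape [k, n]

# Reference result using the indexing from the original question.
reference = a[torch.arange(k), b]
```

Here picked and reference should hold the same values, so which form to use is mostly a matter of taste.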
