For example, I have a tensor a of shape [k, m, n] and a tensor b of shape [k]. What I want is a[torch.arange(k), b]. Is there an API in PyTorch that handles this directly? Creating a new tensor with torch.arange(k) doesn't look that elegant.
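For reference, here is a minimal sketch of the indexing described in the question. The sizes and the values of a and b are made up for illustration. Indexing with two index tensors pairs batch row i with candidate b[i], so the result has shape [k, n].

```python
import torch

# Illustrative sizes (assumed; any positive values work)
k, m, n = 2, 3, 4
a = torch.arange(k * m * n).view(k, m, n)  # shape [k, m, n]
b = torch.tensor([2, 0])                    # shape [k], each entry in [0, m)

# Advanced indexing: for each batch row i, pick candidate b[i]
out = a[torch.arange(k), b]                 # shape [k, n]
print(out)
# tensor([[ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
```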
Hi,
I'm confused about what you expect to get here. What does b contain?
Can you share a small code sample that shows the output you expect?
Sorry, I should have included more information here. Every element in b is in [0, m). You can imagine a sequence ranking task: k is the batch size, and for each batch instance we have m candidate sequences of length n. b holds the index of the chosen candidate for each batch instance, so it has k entries. I think there should be a more straightforward way to handle this, right? It seems like a common need.
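A small sample that spells out the expected output might look like the sketch below. The sizes and random data are assumptions; the explicit loop just makes the semantics unambiguous (row i of the result is candidate b[i] of batch instance i), and it is checked against the arange-based indexing from the question.

```python
import torch

torch.manual_seed(0)
k, m, n = 4, 5, 3  # batch size, candidates per instance, sequence length (assumed)
a = torch.randn(k, m, n)
b = torch.randint(0, m, (k,))  # one chosen candidate index per batch instance, in [0, m)

# Reference semantics: for each batch instance i, keep candidate b[i]
expected = torch.stack([a[i, b[i]] for i in range(k)])  # shape [k, n]

# The arange-based indexing produces the same tensor
assert torch.equal(expected, a[torch.arange(k), b])
```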
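Yes, you can do: a.gather(1, b.view(k, 1, 1).expand(k, 1, n)).squeeze(1)
This gives you what you want: gather returns a tensor of shape [k, 1, n], and squeeze(1) drops the singleton dimension so the result has shape [k, n], the same as a[torch.arange(k), b].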
Basically, use the gather function, and since you want all the elements over the last dimension of size n, you just expand the Tensor of indices there.
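A minimal runnable sketch of the gather approach, with assumed sizes and random data, checked against the arange-based indexing. Note that gather requires an int64 index tensor, which torch.randint produces by default.

```python
import torch

k, m, n = 4, 5, 3  # assumed sizes
a = torch.randn(k, m, n)
b = torch.randint(0, m, (k,))  # int64 indices in [0, m)

# Reshape b to [k, 1, 1], then expand along the last dim to [k, 1, n].
# expand returns a view, so no index data is copied.
index = b.view(k, 1, 1).expand(k, 1, n)

# gather along dim 1: picked[i, 0, j] = a[i, index[i, 0, j], j] = a[i, b[i], j]
picked = a.gather(1, index)  # shape [k, 1, n]
picked = picked.squeeze(1)   # shape [k, n]

assert torch.equal(picked, a[torch.arange(k), b])
```

The expand step is what makes gather select whole candidate sequences: gather picks one element per index entry, so the index must cover every position j in the last dimension.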