Hi.
I am looking for an efficient way to select elements from a tensor along a specific axis and collapse that axis.
Suppose I have a tensor X of shape (N, L, D) and a list of indices I of length N. What I need is a tensor Y of shape (N, D) where Y[n, :] = X[n, I[n], :].
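For reference, here is a minimal sketch of the desired operation using PyTorch advanced indexing, assuming X is a torch tensor and I can be converted to a LongTensor. Indexing with two integer tensors pairs row n with position I[n] and keeps the trailing D dimension whole. The helper name `select_rows` is hypothetical, chosen only for illustration.

```python
import torch

def select_rows(X: torch.Tensor, I) -> torch.Tensor:
    """Hypothetical helper: return Y of shape (N, D) with Y[n, :] = X[n, I[n], :]."""
    N = X.shape[0]
    ids = torch.as_tensor(I, dtype=torch.long, device=X.device)
    # Pair row n with position ids[n] along dim 1; the last dim D is kept whole.
    return X[torch.arange(N, device=X.device), ids]

# Small example: N=3, L=4, D=2, so X[n, l, d] = 8*n + 2*l + d
X = torch.arange(3 * 4 * 2).view(3, 4, 2)
I = [1, 0, 3]
Y = select_rows(X, I)
print(Y)  # tensor([[ 2,  3], [ 8,  9], [22, 23]])
```

No index tensor of shape (N, 1, D) is materialized here, and there is no Python loop.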
I found that torch.gather can do this kind of thing, like so:
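```python
import torch

# X: (N, L, D) tensor; I: list of N indices into dim 1
idx = torch.empty(N, 1, D, dtype=torch.long)
for n in range(N):
    idx[n, :, :] = I[n]  # repeat the index for row n across all D columns
Y = torch.gather(X, 1, idx).squeeze(1)  # gather gives (N, 1, D); squeeze to (N, D)
```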
However, this code is clearly not efficient: it allocates a new index tensor and fills it with a Python for loop.
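One way to keep torch.gather but drop the loop is to reshape I to (N, 1, 1) and expand it to (N, 1, D). This is a sketch assuming I fits in a LongTensor on the same device as X; `expand` returns a view, so the repeated index is not copied. The helper name `select_rows_gather` is hypothetical.

```python
import torch

def select_rows_gather(X: torch.Tensor, I) -> torch.Tensor:
    """Hypothetical helper: loop-free version of the gather approach."""
    N, _, D = X.shape
    ids = torch.as_tensor(I, dtype=torch.long, device=X.device)
    # (N,) -> (N, 1, 1) -> (N, 1, D); expand creates a view, not a copy
    idx = ids.view(N, 1, 1).expand(N, 1, D)
    # Gather along dim 1 gives (N, 1, D); drop the singleton axis to get (N, D)
    return torch.gather(X, 1, idx).squeeze(1)

X = torch.arange(3 * 4 * 2).view(3, 4, 2)
I = [1, 0, 3]
Y = select_rows_gather(X, I)
print(Y)  # tensor([[ 2,  3], [ 8,  9], [22, 23]])
```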
Is there a better way?
Thank you.