I am looking for an efficient way to select and collapse a tensor along a specific axis.
Suppose I have a tensor X of shape (N, L, D) and a list of indices I of length N.
What I need is a tensor Y of shape (N, D) where Y[n, :] = X[n, I[n], :].
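The definition above can be written almost literally with advanced indexing, by pairing each row index n with its column index I[n]. This is a minimal sketch, assuming I is a 1-D int64 tensor; the sizes and values of X and I are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical sizes and data, for illustration only.
N, L, D = 4, 5, 3
X = torch.randn(N, L, D)
I = torch.tensor([0, 4, 2, 1])  # one index into the L axis per row n (int64)

# Indexing the first two axes with (rows, cols) pairs picks X[n, I[n], :]
# for every n at once; the trailing D axis is kept, so Y has shape (N, D).
Y = X[torch.arange(N), I]

# Check against the element-wise definition Y[n, :] = X[n, I[n], :].
for n in range(N):
    assert torch.equal(Y[n], X[n, I[n]])
```

No Python-level loop is needed to build Y; the loop at the end only verifies the result.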
torch.gather can do this, for example:
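    import torch

    # X: (N, L, D) tensor; I: length-N indices into the L axis
    idx = torch.empty(N, 1, D, dtype=torch.long)
    for n in range(N):
        idx[n, :, :] = I[n]
    Y = torch.gather(X, 1, idx).squeeze(1)  # (N, 1, D) -> (N, D)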
However, this code is clearly inefficient: it allocates a new index tensor and fills it with a Python for loop.
Is there a better way?
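A possible loop-free variant of the gather approach builds the index tensor by reshaping I and broadcasting it with expand, instead of filling it row by row. This is a sketch under the assumption that I is a 1-D int64 tensor; the sizes and data are hypothetical.

```python
import torch

# Hypothetical sizes and data, for illustration only.
N, L, D = 4, 5, 3
X = torch.randn(N, L, D)
I = torch.tensor([0, 4, 2, 1])  # int64 indices into the L axis

# gather needs an index tensor with the same number of dims as X.
# view gives shape (N, 1, 1); expand broadcasts it to (N, 1, D)
# as a view, without copying I[n] into D separate slots.
idx = I.view(N, 1, 1).expand(N, 1, D)

# Gather along axis 1 (the L axis), then drop the singleton dimension.
Y = torch.gather(X, 1, idx).squeeze(1)  # (N, 1, D) -> (N, D)
```

Because expand returns a view, the only new allocation here is the gathered output itself.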