I have a tensor X of shape B x N x D: a batch of B sets of N points in D-dimensional space. I also have indices idx of shape B x N x K, where K < N. Each element of idx is an integer in the range 0 to N - 1. I want to index X using idx so that I get a tensor of shape B x N x K x D.
Is this possible? In TensorFlow I would use tf.gather, as it works across many dimensions, but torch.gather only works along one dimension, if I understand the docs correctly.
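For reference, torch.gather does work along a single dimension, but the index can be shaped so that one dimension is enough. A minimal sketch, assuming small illustrative sizes (the names `index` and `expected` are hypothetical): along dim=1, torch.gather computes out[b, j, d] = input[b, index[b, j, d], d], so flattening (N, K) into one axis and repeating each index across the D axis gives the desired B x N x K x D result after a reshape.

```python
import torch

# Illustrative sizes (assumed; any B, N, D, K with K < N work the same way)
B, N, D, K = 4, 10, 3, 5

X = torch.randn(B, N, D)                # B x N x D points
idx = torch.randint(0, N, (B, N, K))    # B x N x K indices in [0, N)

# torch.gather(input, dim, index) needs index to have as many dims as input.
# Along dim=1 it computes: out[b, j, d] = input[b, index[b, j, d], d]
# so merge (N, K) into one axis and repeat each index over the D axis.
index = idx.reshape(B, N * K, 1).expand(-1, -1, D)   # B x (N*K) x D
result = torch.gather(X, 1, index).reshape(B, N, K, D)

# Reference: the explicit triple loop
expected = torch.zeros(B, N, K, D)
for b in range(B):
    for n in range(N):
        for k in range(K):
            expected[b, n, k] = X[b, idx[b, n, k]]

print(result.shape)                   # torch.Size([4, 10, 5, 3])
print(torch.equal(result, expected))  # True
```

The expand call creates a view rather than copying the indices D times, and torch.gather accepts such a view as its index.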
    import torch
    import numpy as np

    B = 64
    N = 128
    D = 3
    K = 20

    X_np = np.random.normal(size=(B, N, D))
    X = torch.from_numpy(X_np)
    idx_np = np.random.randint(0, N, size=(B, N, K))
    idx = torch.from_numpy(idx_np)

    # What I want to do:
    # index X using idx, such that I have B, N, K, D

    # For-loop NumPy version
    result = np.zeros((B, N, K, D))
    for b in range(B):
        for n in range(N):
            for k in range(K):
                i = idx_np[b, n, k]
                result[b, n, k] = X_np[b, i]

    # My PyTorch attempt so far (works by reshaping X to 2 dimensions
    # and keeping an index-offset helper):
    idx_helper = torch.arange(B) * N
    idx_helper = torch.reshape(idx_helper, [B, 1, 1])
    X_flat = torch.reshape(X, [-1, D])

    # The line below works in TensorFlow (as tf.gather). How do I do this in PyTorch?
    # result = torch.gather(X_flat, idx + idx_helper)
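The flat-offset idea above can be completed without torch.gather at all. A minimal sketch, assuming the same sizes as the question (the names `flat_idx`, `result2` and `result3` are hypothetical): indexing the (B*N) x D tensor X_flat with an integer tensor along its first axis behaves like tf.gather along axis 0, and torch.index_select or batched advanced indexing give the same result.

```python
import numpy as np
import torch

B, N, D, K = 64, 128, 3, 20

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
# .long() guards against platforms where NumPy's default int is 32-bit
idx = torch.from_numpy(np.random.randint(0, N, size=(B, N, K))).long()

# Offset each batch's indices by b * N so they address rows of the flattened X
idx_helper = (torch.arange(B) * N).reshape(B, 1, 1)
X_flat = X.reshape(-1, D)          # (B*N) x D
flat_idx = idx + idx_helper        # B x N x K, values in [0, B*N)

# 1) Integer-tensor indexing on the first axis, like tf.gather(X_flat, flat_idx)
result = X_flat[flat_idx]          # B x N x K x D

# 2) index_select needs a 1-D index, so flatten it and reshape the output
result2 = torch.index_select(X_flat, 0, flat_idx.reshape(-1)).reshape(B, N, K, D)

# 3) No flattening: a B x 1 x 1 batch index broadcasts against idx
result3 = X[torch.arange(B).reshape(B, 1, 1), idx]
```

All three variants gather the same rows; the third avoids building idx_helper by letting the batch index broadcast over the N and K axes.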
Any help would be appreciated!