I have a tensor X of shape B x N x D: a batch of B point sets, each containing N points in D-dimensional space.
I also have indices idx of shape B x N x K, where K < N. Each element of idx is an integer in the range [0, N), i.e. a point index into the same batch element. I want to index X using idx so that I get a tensor of shape B x N x K x D, where result[b, n, k] = X[b, idx[b, n, k]].
Is this possible? In TensorFlow I would use tf.gather, which accepts an index tensor of any shape, but if I understand the docs correctly, torch.gather only works along a single dimension.
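For what it's worth, a single-dimension torch.gather may still be enough here, since the gather only happens along the point axis (dim 1). The sketch below assumes idx can be flattened to B x (N*K) and broadcast over D with expand; `gather_points` is a hypothetical helper name, not a PyTorch function.

```python
import torch


def gather_points(X: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, n, k] = X[b, idx[b, n, k]], shape (B, N, K, D)."""
    B, N, D = X.shape
    K = idx.shape[-1]
    # Flatten the (N, K) index grid and repeat it over the D axis,
    # because torch.gather needs an index with the same number of dims as X.
    flat_idx = idx.long().reshape(B, N * K, 1).expand(-1, -1, D)  # (B, N*K, D)
    # Along dim=1: out[b, j, d] = X[b, flat_idx[b, j, d], d]
    out = torch.gather(X, 1, flat_idx)                             # (B, N*K, D)
    return out.reshape(B, N, K, D)


B, N, D, K = 4, 10, 3, 5
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))
out = gather_points(X, idx)
print(out.shape)  # torch.Size([4, 10, 5, 3])
```

The expand call creates a view rather than copying the index D times, so the extra memory is only the reshaped B x (N*K) index.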
Minimal Example
import torch
import numpy as np
B = 64
N = 128
D = 3
K = 20
X_np = np.random.normal(size=(B, N, D))
X = torch.from_numpy(X_np)
idx_np = np.random.randint(0, N, size=(B, N, K))
idx = torch.from_numpy(idx_np)
# Goal: index X with idx to get a tensor of shape (B, N, K, D)
# Reference version with for-loops (NumPy)
result = np.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            i = idx_np[b, n, k]
            result[b, n, k] = X_np[b, i]
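The triple loop can be checked against a vectorized NumPy version. This sketch uses advanced indexing: a batch index of shape (B, 1, 1) broadcasts against idx_np of shape (B, N, K), and the trailing D axis is carried along. The small sizes and the seeded generator are illustrative assumptions.

```python
import numpy as np

B, N, D, K = 4, 10, 3, 5
rng = np.random.default_rng(0)
X_np = rng.normal(size=(B, N, D))
idx_np = rng.integers(0, N, size=(B, N, K))

# (B, 1, 1) batch index broadcasts with (B, N, K) point indices;
# the remaining D axis is appended, giving shape (B, N, K, D).
batch = np.arange(B)[:, None, None]
result_vec = X_np[batch, idx_np]

# Loop version from the question, for comparison.
result_loop = np.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            result_loop[b, n, k] = X_np[b, idx_np[b, n, k]]

print(result_vec.shape)                       # (4, 10, 5, 3)
print(np.array_equal(result_vec, result_loop))  # True
```

PyTorch tensors support the same broadcasting integer indexing, so X[torch.arange(B)[:, None, None], idx] should give the same (B, N, K, D) result.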
# My PyTorch attempt so far: flatten X to 2-D and add a per-batch index offset
idx_helper = torch.arange(B) * N
idx_helper = torch.reshape(idx_helper, [B, 1, 1])
X_flat = torch.reshape(X, [-1, D])
# The equivalent tf.gather call works in TensorFlow, but torch.gather needs a dim argument
# and has different index semantics. How do I do this in PyTorch?
result = torch.gather(X_flat, idx + idx_helper)
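The flattening-plus-offset idea itself carries over to PyTorch if the gather is done with integer-tensor indexing (or index_select) instead of torch.gather. A minimal sketch, assuming idx holds int64 values (torch.randint below produces int64; a tensor from np.random.randint may be int32 on some platforms):

```python
import torch

B, N, D, K = 4, 10, 3, 5
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))

# Offset batch b's indices by b * N so they address rows of X flattened to (B*N, D).
idx_helper = (torch.arange(B) * N).reshape(B, 1, 1)
X_flat = X.reshape(-1, D)  # (B*N, D)

# Indexing the first axis with an integer tensor behaves like tf.gather(X_flat, ...):
# the result has shape idx.shape + (D,) = (B, N, K, D).
result = X_flat[idx + idx_helper]

# Same result with index_select on the flattened 1-D index.
result2 = X_flat.index_select(0, (idx + idx_helper).reshape(-1)).reshape(B, N, K, D)

print(result.shape)                  # torch.Size([4, 10, 5, 3])
print(torch.equal(result, result2))  # True
```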
Any help would be appreciated!