I have a tensor `X` of shape `B x N x D`: a batch of `B` sets of `N` points in `D`-dimensional space.

I also have indices `idx` of shape `B x N x K`, where `K < N`. Each element of `idx` is an integer in the range `[0, N)`, i.e. between 0 and `N - 1`. I want to index `X` using `idx` so that I get a tensor of shape `B x N x K x D`.
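Written out, the indexing rule is `result[b, n, k, :] = X[b, idx[b, n, k], :]`. Below is a minimal sketch of one way to express this in PyTorch. It assumes advanced (integer-tensor) indexing is acceptable. A batch index of shape `B x 1 x 1` broadcasts against `idx`, so each `(b, n, k)` selects the row `X[b, idx[b, n, k]]`. The small sizes are arbitrary and chosen only for illustration.

```python
import torch

B, N, D, K = 2, 5, 3, 4
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))  # values in [0, N)

# Batch index of shape (B, 1, 1); it broadcasts with idx of shape (B, N, K).
batch_idx = torch.arange(B).view(B, 1, 1)

# Advanced indexing on the first two dims; the trailing D dim is kept.
result = X[batch_idx, idx]  # shape (B, N, K, D)

# Reference result built with an explicit loop, for comparison.
expected = torch.empty(B, N, K, D)
for b in range(B):
    for n in range(N):
        for k in range(K):
            expected[b, n, k] = X[b, idx[b, n, k].item()]

print(result.shape)  # torch.Size([2, 5, 4, 3])
```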

Is this possible? In TensorFlow I would use `tf.gather`, as it works across many dimensions, but `torch.gather` only works along one dimension, if I understand the docs correctly.
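`torch.gather` can still be used here if `idx` is reshaped and expanded so that it has the same number of dimensions as `X`. `torch.gather(input, dim, index)` requires `index` to have as many dimensions as `input`. Along `dim=1` it computes `out[b][j][d] = input[b][index[b][j][d]][d]`. The sketch below flattens the `N x K` axes into one intermediate axis of length `N*K`, repeats each index across the `D` features, and reshapes back afterwards. The sizes are illustrative only.

```python
import torch

B, N, D, K = 2, 5, 3, 4
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))

# Flatten (N, K) into one axis and repeat each index across the D features:
# index has shape (B, N*K, D), with index[b, j, d] == idx.reshape(B, -1)[b, j].
index = idx.reshape(B, N * K, 1).expand(-1, -1, D)

# out[b, j, d] = X[b, index[b, j, d], d]  ->  shape (B, N*K, D)
gathered = torch.gather(X, 1, index)

# Restore the separate (N, K) axes.
result = gathered.reshape(B, N, K, D)
print(result.shape)  # torch.Size([2, 5, 4, 3])
```

The `expand` call does not copy memory. The only new allocation is `gathered` itself.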

Minimal Example

```
import torch
import numpy as np

B = 64
N = 128
D = 3
K = 20
X_np = np.random.normal(size=(B, N, D))
X = torch.from_numpy(X_np)
idx_np = np.random.randint(0, N, size=(B, N, K))
idx = torch.from_numpy(idx_np)

# What I want to do:
# index X using idx, such that the result has shape (B, N, K, D)

# For-loop NumPy version
result = np.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            i = idx_np[b, n, k]
            result[b, n, k] = X_np[b, i]

# My PyTorch attempt so far: reshape X to 2 dimensions and add a per-batch index offset
idx_helper = torch.arange(B) * N
idx_helper = torch.reshape(idx_helper, [B, 1, 1])
X_flat = torch.reshape(X, [-1, D])
# In TensorFlow this works on the flattened tensor:
#     result = tf.gather(X_flat, idx + idx_helper)
# How do I do this in PyTorch? The call below fails, because torch.gather
# needs a `dim` argument and an index with as many dims as the input:
# result = torch.gather(X, idx + idx_helper)
```
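The offset idea in the attempt above can be completed in PyTorch by replacing the final `tf.gather` call with plain integer-tensor indexing on the flattened tensor. `X_flat[i]`, with an integer tensor `i` of any shape, returns a tensor of shape `i.shape + (D,)`. The sketch below is self-contained and uses small sizes. It assumes `idx` holds int64 values, which is what `torch.randint` produces by default. `torch.index_select` is shown as an alternative, but it only accepts a 1-D index, so the index has to be flattened and the result reshaped.

```python
import torch

B, N, D, K = 2, 5, 3, 4
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))

# Offset so that batch b's rows start at row b * N of the flattened view.
idx_helper = (torch.arange(B) * N).reshape(B, 1, 1)
X_flat = X.reshape(-1, D)  # shape (B*N, D)
flat_idx = idx + idx_helper  # shape (B, N, K), values in [0, B*N)

# Integer-tensor indexing on dim 0: result shape is (B, N, K) + (D,).
result = X_flat[flat_idx]

# Equivalent with index_select, which requires a 1-D index.
result_sel = torch.index_select(X_flat, 0, flat_idx.reshape(-1)).reshape(B, N, K, D)
```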

Any help would be appreciated!