Indexing 3D Tensor using 3D Index

I have a tensor X of shape B x N x D: a batch of B sets of N points in D-dimensional space.

I also have an index tensor idx of shape B x N x K, where K < N. Each element of idx is an integer in the range [0, N), i.e. between 0 and N - 1. I want to index X with idx to get a tensor of shape B x N x K x D, where entry [b, n, k] is the point X[b, idx[b, n, k]].

Is this possible? In TensorFlow I would use tf.gather, which handles multi-dimensional indices, but if I understand the docs correctly, torch.gather only gathers along a single dimension.
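
torch.gather does gather along one dimension only, but with dim=1 it can still produce this result if the index is flattened to B x (N*K), expanded across the trailing D axis, and the output reshaped afterwards. Per the torch.gather docs, for dim=1 the rule is out[b][j][d] = input[b][index[b][j][d]][d], and the index must have as many dimensions as the input. A minimal sketch, assuming X is a float tensor of shape (B, N, D) and idx an int64 tensor of shape (B, N, K); the helper name gather_neighbors and the small sizes are hypothetical:

```python
import torch

def gather_neighbors(X: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Return a (B, N, K, D) tensor with out[b, n, k] = X[b, idx[b, n, k]].

    Hypothetical helper: assumes X is (B, N, D) and idx is int64 (B, N, K).
    """
    B, N, D = X.shape
    K = idx.shape[-1]
    # Flatten the (N, K) index grid into one axis, then repeat it across D so
    # the index has the same number of dims as X, as torch.gather requires.
    flat_idx = idx.reshape(B, N * K, 1).expand(B, N * K, D)
    # Along dim=1: out[b, j, d] = X[b, flat_idx[b, j, d], d]
    out = torch.gather(X, 1, flat_idx)
    # Restore the (N, K) grid: (B, N*K, D) -> (B, N, K, D)
    return out.reshape(B, N, K, D)

# Usage with small, arbitrary sizes
B, N, D, K = 2, 5, 3, 4
X = torch.randn(B, N, D)
idx = torch.randint(0, N, (B, N, K))
out = gather_neighbors(X, idx)
print(out.shape)  # torch.Size([2, 5, 4, 3])
```

The expand call creates a view rather than copying the index D times, so the extra memory is only the gathered output itself.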

Minimal Example

import torch
import numpy as np

B = 64
N = 128
D = 3
K = 20

X_np = np.random.normal(size=(B, N, D))
X = torch.from_numpy(X_np)

idx_np = np.random.randint(0, N, size=(B, N, K))
idx = torch.from_numpy(idx_np)

# What I want to do:
# index X with idx to get a tensor of shape (B, N, K, D)

# For-loop NumPy version (uses the NumPy arrays, not the torch tensors)
result = np.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            i = idx_np[b, n, k]
            result[b, n, k] = X_np[b, i]

# My PyTorch attempt so far: flatten X to 2 dimensions and add a per-batch index offset
idx_helper = torch.arange(B) * N
idx_helper = torch.reshape(idx_helper, [B, 1, 1])
X_flat = torch.reshape(X, [-1, D])
# In TensorFlow this works: result = tf.gather(X_flat, idx + idx_helper)
# How do I do this in PyTorch? torch.gather(X_flat, idx + idx_helper) fails:
# it requires a dim argument and an index with as many dims as the input.
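
The flatten-plus-offset idea in the attempt above does carry over to PyTorch; what changes is the call. TensorFlow's tf.gather(X_flat, indices) picks whole rows of X_flat along axis 0, and the PyTorch counterpart of that is plain integer (advanced) indexing on the first axis, X_flat[indices], which returns a tensor of shape indices.shape + (D,). A sketch under the same setup as the minimal example, with sizes reduced (arbitrarily) so the loop check runs quickly:

```python
import numpy as np
import torch

B, N, D, K = 4, 16, 3, 5
X = torch.from_numpy(np.random.normal(size=(B, N, D)))
idx = torch.from_numpy(np.random.randint(0, N, size=(B, N, K)))

# Offset each batch's indices by b * N so they address rows of the flattened X.
idx_helper = (torch.arange(B) * N).reshape(B, 1, 1)
X_flat = X.reshape(-1, D)            # (B * N, D)
flat_idx = idx + idx_helper          # (B, N, K), values in [0, B * N)

# Advanced indexing along dim 0: result[b, n, k] = X_flat[flat_idx[b, n, k]]
result = X_flat[flat_idx]            # (B, N, K, D)

# Reference: the explicit triple loop from the question.
expected = np.zeros((B, N, K, D))
for b in range(B):
    for n in range(N):
        for k in range(K):
            expected[b, n, k] = X[b, int(idx[b, n, k])].numpy()

print(result.shape)                           # torch.Size([4, 16, 5, 3])
print(np.allclose(result.numpy(), expected))  # True
```

An equivalent alternative is torch.index_select(X_flat, 0, flat_idx.reshape(-1)).reshape(B, N, K, D), since index_select takes a 1-D index.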

Any help would be appreciated!

I have the same question.
To clarify, suppose I have a tensor A with A.shape = 2 x 3 x 2 (B x N x C):

A = torch.tensor(
    [[[ 0.4930,  0.1150],
      [-0.2355,  1.1917],
      [-1.2421, -0.4383]],
     [[ 0.3099,  3.4751],
      [ 0.7780,  1.0990],
      [-0.0795,  0.1633]]])

and an index tensor idx with idx.shape = 2 x 3 x 1 (B x N x K):

idx = torch.tensor(
    [[[0],
      [1],
      [2]],
     [[1],
      [2],
      [1]]])

I would like to get a 4D tensor res with res.shape = 2 x 3 x 1 x 2 (B x N x K x C):

res = torch.tensor(
    [[[[ 0.4930,  0.1150]],
      [[-0.2355,  1.1917]],
      [[-1.2421, -0.4383]]],
     [[[ 0.7780,  1.0990]],
      [[-0.0795,  0.1633]],
      [[ 0.7780,  1.0990]]]])

Currently, I implement it as:

res = torch.cat([A[i, idx[i]] for i in range(B)], dim=0).view(B, N, K, C)
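
One way to drop the loop, sketched here with the small example's tensors, is to index A with two integer tensors that broadcast against each other: a batch index of shape (B, 1, 1) alongside idx of shape (B, N, K). Advanced indexing broadcasts them to (B, N, K) and appends the remaining C axis, so res[b, n, k] = A[b, idx[b, n, k]]. The name batch_idx is illustrative; the comparison uses the loop version from the question:

```python
import torch

A = torch.tensor(
    [[[ 0.4930,  0.1150],
      [-0.2355,  1.1917],
      [-1.2421, -0.4383]],
     [[ 0.3099,  3.4751],
      [ 0.7780,  1.0990],
      [-0.0795,  0.1633]]])
idx = torch.tensor([[[0], [1], [2]],
                    [[1], [2], [1]]])
B, N, C = A.shape
K = idx.shape[-1]

# Loop version from the question, kept for comparison.
res_loop = torch.cat([A[i, idx[i]] for i in range(B)], dim=0).view(B, N, K, C)

# Loop-free version: batch_idx (B, 1, 1) broadcasts against idx (B, N, K).
batch_idx = torch.arange(B).view(B, 1, 1)
res = A[batch_idx, idx]              # shape (B, N, K, C)

print(res.shape)                     # torch.Size([2, 3, 1, 2])
print(torch.equal(res, res_loop))    # True
```

The same expression answers the original question directly: X[torch.arange(B).view(B, 1, 1), idx] has shape (B, N, K, D) without flattening X or computing offsets.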

Is there a more efficient way to do this, e.g. one without the for loop?
Thanks in advance!