Multi-dimension torch.take (tf.gather_nd)

Hi all,
I tried some of the available APIs in PyTorch, but I don't think any of them meets my requirement.
The closest one is torch.take, but it treats the input as a flattened 1D tensor. tf.gather_nd would probably work, but I couldn't find an equivalent function in PyTorch.

Here’s the task:
I have a 3D tensor of shape (B, H, K) that holds indices.
I would like to use these indices to extract values from a map of shape (B, N, M).
The desired output should have shape (B, H, KM).
All the values in the indices tensor are within [0, N).
That means every element of the indices tensor corresponds to one vector of shape (M) from the map, and I want to look one up H * K times per batch element.
The operation should run in batch mode, so the final output should have shape (B, H, KM).
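To make the expected layout concrete, here is a minimal per-element sketch of the lookup described above. The names `lookup_reference` and `feature_map` are made up for illustration (`feature_map` stands in for the map, since `map` is a Python builtin), and the tiny input values are invented.

```python
import torch

def lookup_reference(indices, feature_map):
    """Hypothetical reference: slow, but spells out the intended result."""
    B, H, K = indices.shape
    M = feature_map.shape[-1]
    out = torch.empty(B, H, K * M, dtype=feature_map.dtype)
    for b in range(B):
        for h in range(H):
            for k in range(K):
                n = int(indices[b, h, k])
                # The k-th index fills the k-th chunk of length M.
                out[b, h, k * M:(k + 1) * M] = feature_map[b, n]
    return out

# Invented toy data: B=1, N=3, M=2, H=1, K=2
feature_map = torch.tensor([[[0., 1.], [10., 11.], [20., 21.]]])
indices = torch.tensor([[[2, 0]]])

print(lookup_reference(indices, feature_map).tolist())
# [[[20.0, 21.0, 0.0, 1.0]]]
```

Row 2 of the map is copied first and row 0 second, so the K looked-up vectors sit side by side in the last dimension.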
The easiest way to solve this task is to use a double for-loop. I'm wondering if there's another way to solve it without a for-loop.
Here’s the example:

# indices (B, H, K)
# map (B, N, M)
# for H
for j in range(H):
    k_block = []

    # for K
    for i in range(K):
        # b, 1
        index = indices[:, j, i]
        k_block.append(map[range(B), index])

    # b, KM
    k_block = torch.cat(k_block, dim=-1)

Indexing the tensor might work.
I assume the initialization of k_block should be outside the first loop?
Otherwise, you would throw away the first results.

Based on this assumption, this code should work:

# Setup
import torch

B, H, K = 2, 3, 4
N, M = 5, 6

indices = torch.randint(0, N, (B, H, K))
x = torch.randn(B, N, M)

# Your approach
k_block = []
for j in range(H):

    # for K
    for i in range(K):
        # b, 1
        index = indices[:, j, i]
        k_block.append(x[range(B), index])

# (B, H*K*M): all H*K blocks of length M concatenated along the last dim
k_block = torch.cat(k_block, dim=-1)

# Indexing approach
x1 = x.clone()
res = x1[torch.arange(B).unsqueeze(1).unsqueeze(2), indices].view(B, -1)

# Test
print((k_block - res).abs().max())
> tensor(0.)
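
If the (B, H, KM) layout from the question is needed rather than the flattened one, the same indexing result can probably just be reshaped. The sketch below assumes the setup above (the seed is added only for reproducibility) and cross-checks the result against torch.gather, which is one possible alternative and not necessarily the faster one. The names `batch_idx`, `picked`, `flat_idx` and `out_gather` are only illustrative.

```python
import torch

B, H, K = 2, 3, 4
N, M = 5, 6

torch.manual_seed(0)
indices = torch.randint(0, N, (B, H, K))
x = torch.randn(B, N, M)

# A batch index of shape (B, 1, 1) broadcasts against indices of shape
# (B, H, K), so each index selects one M-vector: result is (B, H, K, M).
batch_idx = torch.arange(B).view(B, 1, 1)
picked = x[batch_idx, indices]

# Merge the last two dimensions to get (B, H, K*M).
out = picked.reshape(B, H, K * M)

# Alternative sketch: torch.gather along dim 1, with the indices
# expanded over the M dimension so every column is gathered.
flat_idx = indices.reshape(B, H * K, 1).expand(-1, -1, M)
out_gather = torch.gather(x, 1, flat_idx).reshape(B, H, K * M)

print(out.shape)                      # torch.Size([2, 3, 24])
print(torch.equal(out, out_gather))   # True
```

The indexing version is shorter. The gather version may be useful where advanced indexing is not supported, for example in some export paths, but that is an assumption worth checking for your use case.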

Hi ptrblck,

Thank you for your fast reply!
It did solve my problem. I didn't expect it to be so easy!
Thanks again.