# Multi-dimension torch.take (tf.gather_nd)

Hi all,
I tried some of the available APIs in PyTorch, but I think none of them meets my requirement.
The most similar API in PyTorch is `torch.take`, but it treats the input as a 1D tensor. `tf.gather_nd` would probably work, but I couldn't find an equivalent function in PyTorch.
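To illustrate the 1D limitation: `torch.take` indexes into the input as if it were flattened in row-major order, so a multi-dimensional lookup has to be turned into flat offsets by hand. A small sketch (the tensor `t` is just an illustrative example):

```python
import torch

# A 2 x 3 tensor: [[0, 1, 2], [3, 4, 5]]
t = torch.arange(6).view(2, 3)

# torch.take treats t as its row-major flattened view,
# so flat index 4 is t[1, 1] and flat index 5 is t[1, 2]
picked = torch.take(t, torch.tensor([0, 4, 5]))
print(picked)  # tensor([0, 4, 5])
```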

I have a 3D tensor of shape (B, H, K) as indices.
I would like to extract vectors from a map of shape (B, N, M) using these indices.
All values in the indices tensor are within [0, N).
That means every element of the indices selects one vector of shape (M) from the map, and I want to do this lookup H * K times.
The operation should run in batch mode, so the final output should have shape (B, H, KM), i.e. the K selected vectors concatenated along the last dimension.
The easiest way to solve this task is a double for-loop. I'm wondering if there's a way to solve it without the loops.
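One loop-free option, sketched here under the assumption that the shapes are exactly as described above, is `torch.gather` along dim 1 with the indices expanded over the last dimension so that each lookup copies a whole row of length M. The names `feat` (standing in for `map`, to avoid shadowing the builtin), `idx`, and `ref` are illustrative only:

```python
import torch

B, H, K = 2, 3, 4
N, M = 5, 6

indices = torch.randint(0, N, (B, H, K))  # values in [0, N)
feat = torch.randn(B, N, M)               # stands in for the "map" tensor

# Repeat each index along the last dim so gather copies a whole row:
# idx[b, h*K + k, m] = indices[b, h, k]
idx = indices.reshape(B, H * K, 1).expand(-1, -1, M)

# out[b, p, m] = feat[b, idx[b, p, m], m]  -> shape (B, H*K, M)
out = torch.gather(feat, 1, idx)

# Concatenate the K vectors for each H position: (B, H, K*M)
out = out.reshape(B, H, K * M)

# Reference: the double loop, with the K results concatenated per j
ref = torch.stack(
    [
        torch.cat([feat[torch.arange(B), indices[:, j, i]] for i in range(K)], dim=-1)
        for j in range(H)
    ],
    dim=1,
)
print(torch.equal(out, ref))  # True
```

`gather` only copies values, so the comparison can use exact equality rather than a tolerance.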
Here’s the example:

```python
# indices: (B, H, K)
# map: (B, N, M)
# loop over H
for j in range(H):
    k_block = []

    # loop over K
    for i in range(K):
        # (B,)
        index = indices[:, j, i]
        # (B, M)
        k_block.append(map[range(B), index])

    # (B, K*M)
    k_block = torch.cat(k_block, dim=-1)
    ...
```

Indexing the tensor might work.
I assume `k_block` should be initialized outside the outer loop?
Otherwise, you would throw away the results of the earlier iterations.

Based on this assumption, this code should work:

```python
import torch

# Setup
B, H, K = 2, 3, 4
N, M = 5, 6

indices = torch.randint(0, N, (B, H, K))
x = torch.randn(B, N, M)

# Loop approach (k_block initialized once, outside both loops)
k_block = []
for j in range(H):
    # for K
    for i in range(K):
        # (B,)
        index = indices[:, j, i]
        # (B, M)
        k_block.append(x[range(B), index])

# (B, H*K*M)
k_block = torch.cat(k_block, dim=-1)

# Indexing approach: the (B, 1, 1) batch index broadcasts against
# indices (B, H, K), giving (B, H, K, M), flattened to (B, H*K*M)
x1 = x.clone()
res = x1[torch.arange(B).unsqueeze(1).unsqueeze(2), indices].view(B, -1)

# Test
print((k_block - res).abs().max())
# tensor(0.)
```
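To get the (B, H, KM) layout from the question directly, and to connect back to the flattened view that `torch.take` uses, here is an alternative sketch (not part of the reply above) that folds the batch dimension into a flat row offset and then does a plain 1D row lookup with `index_select`. The names `offsets`, `flat_idx`, `res_flat`, and `res_adv` are illustrative only:

```python
import torch

B, H, K = 2, 3, 4
N, M = 5, 6

indices = torch.randint(0, N, (B, H, K))
x = torch.randn(B, N, M)

# Row b*N + n of the (B*N, M) view is x[b, n]
offsets = torch.arange(B).view(B, 1, 1) * N  # (B, 1, 1), broadcasts over H and K
flat_idx = (indices + offsets).reshape(-1)   # (B*H*K,)

# 1D row lookup, then restore the (B, H, K*M) layout
rows = x.reshape(B * N, M).index_select(0, flat_idx)  # (B*H*K, M)
res_flat = rows.reshape(B, H, K * M)

# Same values as the advanced-indexing version, reshaped to (B, H, K*M)
res_adv = x[torch.arange(B).view(B, 1, 1), indices].reshape(B, H, K * M)
print(torch.equal(res_flat, res_adv))  # True
```

Both versions produce the same memory order (b, h, k, m), so reshaping to (B, H, K*M) concatenates the K selected vectors per H position.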

Hi ptrblck,