Hi all,

I tried some of the available APIs in PyTorch, but I think none of them meet my requirements.

The most similar API in PyTorch is torch.take, but it treats the input as a 1D tensor. TensorFlow's tf.gather_nd would probably work, but I couldn't find an equivalent function in PyTorch.
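For reference, here is a minimal sketch of how torch.take could still be used, assuming indices is an int64 tensor with values in [0, N - 1]. torch.take reads its input as a flattened row-major 1D tensor and returns a result shaped like the index tensor. That means each (b, n, m) lookup can be turned into a linear offset (b * N + n) * M + m. The function name take_rows is hypothetical.

```python
import torch

def take_rows(src: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. src: (B, N, M); indices: (B, H, K) int64 in [0, N - 1]
    B, H, K = indices.shape
    N, M = src.shape[1:]
    # Start offset of row (b, n) in the flattened src is (b * N + n) * M -> (B, H, K)
    batch = torch.arange(B, device=src.device).view(B, 1, 1)
    base = (batch * N + indices) * M
    # Add the column offsets 0..M-1 -> (B, H, K, M)
    flat_idx = base.unsqueeze(-1) + torch.arange(M, device=src.device)
    # torch.take indexes src as if it were 1D; the result has flat_idx's shape
    out = torch.take(src, flat_idx)
    # Concatenate the K gathered rows along the last axis -> (B, H, K*M)
    return out.reshape(B, H, K * M)
```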

Here’s the task:

I have a 3D tensor of shape (B, H, K) that holds indices.

I would like to use these indices to extract vectors from a map of shape (B, N, M).

The desired output has shape (B, H, K*M).

All values in the indices tensor are within [0, N - 1], i.e. they are valid row indices into the map.

That means every element of indices selects one vector of shape (M,) from the map, and I want to perform this lookup H * K times per batch element.

The operation should run in batch mode over B.
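Written elementwise (0-based indices, with the K gathered vectors concatenated in order along the last axis, as the loop further down does with torch.cat), the requested operation is:

$$
\mathrm{out}[b,\ h,\ kM + m] = \mathrm{map}\big[b,\ \mathrm{indices}[b, h, k],\ m\big],
\qquad 0 \le b < B,\ 0 \le h < H,\ 0 \le k < K,\ 0 \le m < M
$$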

The easiest way to solve this is a double for-loop. Is there a way to do it without for-loops?
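One possible loop-free formulation, sketched under the assumption that indices is int64 and the map is passed as src. The name src is hypothetical and was chosen to avoid shadowing Python's built-in map. A batch index of shape (B, 1, 1) broadcasts against indices of shape (B, H, K), so src[batch_idx, indices] picks one length-M row per (b, h, k). That yields (B, H, K, M), which reshapes to (B, H, K*M).

```python
import torch

def gather_rows(src: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. src: (B, N, M); indices: (B, H, K) int64 in [0, N - 1]
    B, H, K = indices.shape
    M = src.shape[-1]
    # (B, 1, 1) batch index broadcasts against indices -> (B, H, K)
    batch_idx = torch.arange(B, device=src.device).view(B, 1, 1)
    # Advanced indexing on dims 0 and 1 keeps the trailing M dim -> (B, H, K, M)
    out = src[batch_idx, indices]
    # Merge the K rows into one vector per (b, h) -> (B, H, K*M)
    return out.reshape(B, H, K * M)
```

Because the result of advanced indexing is a fresh contiguous tensor, the final reshape does not need an extra copy.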

Here’s the example:

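```
import torch

# indices: (B, H, K), int64, values in [0, N - 1]
# map: (B, N, M)
B, H, K = indices.shape
batch = torch.arange(B)
# loop over H
for j in range(H):
    k_block = []
    # loop over K
    for i in range(K):
        # (B,)
        index = indices[:, j, i]
        # one row of map per sample -> (B, M)
        k_block.append(map[batch, index])
    # (B, K*M)
    k_block = torch.cat(k_block, dim=-1)
    ...
```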
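Here is a self-contained check that the double loop and a torch.gather-based version agree. Both function names are hypothetical. Collecting each k_block with torch.stack over H is an assumption about what the elided `...` does. torch.gather needs an index with the same number of dimensions as the input, so the indices are flattened to (B, H*K, 1) and expanded across the M columns.

```python
import torch

def loop_reference(src, indices):
    # The double loop above; stacking over H is an assumed completion of "..."
    B, H, K = indices.shape
    batch = torch.arange(B, device=src.device)
    h_blocks = []
    for j in range(H):
        k_block = [src[batch, indices[:, j, i]] for i in range(K)]  # K x (B, M)
        h_blocks.append(torch.cat(k_block, dim=-1))                 # (B, K*M)
    return torch.stack(h_blocks, dim=1)                             # (B, H, K*M)

def gather_rows_via_gather(src, indices):
    B, H, K = indices.shape
    M = src.shape[-1]
    # Repeat each row index across the M columns -> (B, H*K, M)
    idx = indices.reshape(B, H * K, 1).expand(B, H * K, M)
    # out[b, p, m] = src[b, idx[b, p, m], m] -> (B, H*K, M)
    out = torch.gather(src, 1, idx)
    return out.reshape(B, H, K * M)

src = torch.randn(2, 5, 3)
indices = torch.randint(0, 5, (2, 4, 2))
print(gather_rows_via_gather(src, indices).shape)  # torch.Size([2, 4, 6])
```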