I tried some of the available APIs in PyTorch, but I don't think any of them meets my requirement.
The closest API in PyTorch is torch.take, but it treats the input as a 1D tensor. In TensorFlow, tf.gather_nd would probably work, but I couldn't find an equivalent function in PyTorch.
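For reference, here is a small sketch of how torch.take behaves. The tensor `x` and the coordinates are made up for illustration. torch.take indexes the row-major flattened view of its input, so a multi-dimensional lookup (similar in spirit to tf.gather_nd) first has to be converted into flat offsets.

```python
import torch

# Small illustrative tensor of shape (2, 3); values chosen arbitrarily.
x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])

# torch.take ignores the shape of x and indexes its flattened view,
# so index 4 refers to x.flatten()[4], which is 50.
picked = torch.take(x, torch.tensor([0, 4, 5]))
print(picked)  # tensor([10, 50, 60])

# Emulating a 2-D (row, col) lookup with torch.take:
# convert each coordinate into a flat offset, row * num_cols + col.
rows = torch.tensor([1, 0])
cols = torch.tensor([2, 1])
flat = rows * x.shape[1] + cols
print(torch.take(x, flat))  # tensor([60, 20])
```

Because torch.take returns individual scalars, gathering whole M-vectors this way would need one offset per element, which gets verbose quickly.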
Here’s the task:
I have a 3D tensor of shape (B, H, K) containing indices.
I would like to use these indices to extract vectors from a map of shape (B, N, M).
The desired output has shape (B, H, K*M).
All values in the indices tensor are in the range [0, N), i.e. valid row indices from 0 to N-1.
That means each index selects one vector of shape (M,) from the map, and I want to perform this lookup H * K times.
The operation should be batched over B, so the final output has shape (B, H, K*M).
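A minimal sketch of the batched lookup using advanced indexing, assuming `indices` is an int64 tensor with values in [0, N). The helper name `gather_vectors` is hypothetical, and the map is called `feature_map` here only to avoid shadowing Python's built-in `map`.

```python
import torch

def gather_vectors(feature_map: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: feature_map (B, N, M), indices (B, H, K) -> (B, H, K*M)."""
    B, H, K = indices.shape
    M = feature_map.shape[-1]
    # A batch index of shape (B, 1, 1) broadcasts against indices (B, H, K),
    # so element [b, h, k] selects the row feature_map[b, indices[b, h, k]].
    batch_idx = torch.arange(B, device=indices.device).view(B, 1, 1)
    picked = feature_map[batch_idx, indices]  # (B, H, K, M)
    # Merge the last two dims: the K selected M-vectors end up side by side.
    return picked.reshape(B, H, K * M)

# Usage with arbitrary example sizes.
B, H, K, N, M = 2, 3, 4, 5, 6
feature_map = torch.randn(B, N, M)
indices = torch.randint(0, N, (B, H, K))  # values in [0, N)
out = gather_vectors(feature_map, indices)
print(out.shape)  # torch.Size([2, 3, 24])
```

Flattening (K, M) in row-major order with reshape keeps the K vectors in order, which matches concatenating them along the last dimension.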
The easiest way to solve this is with a double for-loop. I'm wondering whether there's a way to do it without loops.
Here's the loop version:
```python
import torch

# indices: (B, H, K)
# map: (B, N, M)
# loop over H
for j in range(H):
    k_block = []
    # loop over K
    for i in range(K):
        # index: (B,)
        index = indices[:, j, i]
        # map[range(B), index]: (B, M)
        k_block.append(map[range(B), index])
    # k_block: (B, K*M)
    k_block = torch.cat(k_block, dim=-1)
    ...
```
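One possible loop-free alternative is torch.gather along dim 1. The sketch below (function names are hypothetical) compares it against a reconstruction of the loop above. The original elides how the per-j blocks are combined, so stacking them along dim 1 is an assumption here.

```python
import torch

def loop_version(feature_map, indices):
    # Mirrors the double loop; stacking the per-j blocks along dim 1
    # is an assumption about what the elided "..." does.
    B, H, K = indices.shape
    blocks = []
    for j in range(H):
        k_block = []
        for i in range(K):
            index = indices[:, j, i]                               # (B,)
            k_block.append(feature_map[torch.arange(B), index])    # (B, M)
        blocks.append(torch.cat(k_block, dim=-1))                  # (B, K*M)
    return torch.stack(blocks, dim=1)                              # (B, H, K*M)

def gather_version(feature_map, indices):
    # torch.gather needs an index with as many dims as the input, so flatten
    # (H, K) and repeat each row index across the M feature columns.
    B, H, K = indices.shape
    M = feature_map.shape[-1]
    idx = indices.reshape(B, H * K, 1).expand(B, H * K, M)  # (B, H*K, M)
    picked = torch.gather(feature_map, 1, idx)              # (B, H*K, M)
    return picked.reshape(B, H, K * M)

# Compare both versions on random data with arbitrary sizes.
torch.manual_seed(0)
B, H, K, N, M = 2, 3, 4, 5, 6
feature_map = torch.randn(B, N, M)
indices = torch.randint(0, N, (B, H, K))
print(torch.equal(loop_version(feature_map, indices),
                  gather_version(feature_map, indices)))  # True
```

`expand` returns a view, so repeating the indices across M does not allocate a new index tensor; the only new allocation is the gathered output.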