Hi all,
I tried several of the available PyTorch APIs, but none of them seems to meet my requirements.
The closest PyTorch API I found is torch.take, but it treats the input as a flattened 1D tensor. In TensorFlow, tf.gather_nd would probably work, but I couldn't find an equivalent function in PyTorch.
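As far as I know there is no built-in gather_nd in PyTorch, but a tf.gather_nd-style lookup (with batch_dims=0) can be approximated by folding the addressed leading dimensions into one linear index and then using ordinary advanced indexing. The helper name `gather_nd` below is hypothetical; it is a minimal sketch, assuming the last dimension of `idx` addresses the first D dimensions of `params`, as in tf.gather_nd.

```python
import torch

def gather_nd(params: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mimicking tf.gather_nd with batch_dims=0.
    # idx has shape (..., D); each length-D row addresses the first D dims
    # of params. Result shape: idx.shape[:-1] + params.shape[D:].
    D = idx.shape[-1]
    lead, tail = params.shape[:D], params.shape[D:]
    # Convert each D-dim index to a row-major linear index (Horner's scheme).
    flat = torch.zeros(idx.shape[:-1], dtype=torch.long, device=idx.device)
    for d in range(D):
        flat = flat * lead[d] + idx[..., d].long()
    # Merge the leading D dims into one, then use plain advanced indexing;
    # the remaining (tail) dims are carried along unchanged.
    return params.reshape(-1, *tail)[flat]
```

For example, with `params = torch.arange(24).reshape(2, 3, 4)`, an `idx` row `[1, 2]` selects the vector `params[1, 2]`, and a row `[1, 2, 3]` selects the scalar `params[1, 2, 3]`.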
Here’s the task:
I have a 3D tensor of indices with shape (B, H, K).
I would like to use these indices to extract vectors from a map of shape (B, N, M).
The desired output has shape (B, H, K*M).
All values in the indices tensor are in the range [0, N), i.e. from 0 to N - 1.
That means each element of indices selects one vector of shape (M,) from the corresponding batch of the map, and I want to perform this lookup H * K times per batch element.
The operation should be batched over B, so the final output has shape (B, H, K*M).
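Written element-wise, the lookup the loop below performs is the following, assuming 0-based indexing and that the K selected vectors are concatenated in order of k (as the torch.cat call does):

$$
\mathrm{out}[b,\, h,\, kM + m] = \mathrm{map}\big[b,\ \mathrm{indices}[b, h, k],\ m\big],
\qquad 0 \le b < B,\ 0 \le h < H,\ 0 \le k < K,\ 0 \le m < M.
$$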
The simplest solution is a double for-loop, but I'm wondering whether there is a way to do this without explicit loops.
Here’s the example:
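```python
import torch

# indices: (B, H, K) integer tensor, values in [0, N)
# map:     (B, N, M)
# loop over H
for j in range(H):
    k_block = []
    # loop over K
    for i in range(K):
        # (B,): one index per batch element
        index = indices[:, j, i]
        # map[b, index[b]] for every b -> (B, M)
        k_block.append(map[torch.arange(B), index])
    # (B, K*M)
    k_block = torch.cat(k_block, dim=-1)
    ...
```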
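One loop-free sketch, assuming `indices` is an int64 tensor with values in [0, N) on the same device as the map, and renaming `map` to `table` to avoid shadowing the Python builtin. It shows three formulations that should give the same (B, H, K*M) result: advanced indexing with a broadcast batch index, torch.gather along dim 1, and torch.take on row-major flat offsets (the 1D view of torch.take mentioned above). The function names are hypothetical.

```python
import torch

def gather_vectorized(indices: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # indices: (B, H, K) int64, values in [0, N); table: (B, N, M)
    B, H, K = indices.shape
    M = table.shape[-1]
    # A (B, 1, 1) batch index broadcasts against indices (B, H, K),
    # so table[b_idx, indices] has shape (B, H, K, M).
    b_idx = torch.arange(B, device=indices.device).view(B, 1, 1)
    out = table[b_idx, indices]
    # Merge K and M: each row concatenates the K looked-up vectors.
    return out.reshape(B, H, K * M)

def gather_with_torch_gather(indices: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    B, H, K = indices.shape
    M = table.shape[-1]
    # torch.gather needs an index with as many dims as table: (B, H*K, M),
    # where idx[b, p, m] = indices.flatten(1)[b, p] for every m.
    idx = indices.reshape(B, H * K, 1).expand(B, H * K, M)
    out = torch.gather(table, 1, idx)  # out[b, p, m] = table[b, idx[b, p, m], m]
    return out.reshape(B, H, K * M)

def gather_with_take(indices: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    B, H, K = indices.shape
    N, M = table.shape[1], table.shape[2]
    # torch.take reads table as if flattened to 1D (row-major), so build the
    # flat offset (b * N + indices[b, h, k]) * M + m for every output element.
    b = torch.arange(B, device=indices.device).view(B, 1, 1, 1)
    m = torch.arange(M, device=indices.device).view(1, 1, 1, M)
    flat = (b * N + indices.unsqueeze(-1)) * M + m  # (B, H, K, M)
    return torch.take(table, flat).reshape(B, H, K * M)
```

The advanced-indexing version is probably the easiest to read; the torch.gather version needs the explicit expand because gather requires the index to have the same number of dimensions as the input; the torch.take version shows how the 1D semantics can still be used once the offsets are computed by hand.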