```python
import torch

A = torch.randn(5, 21, 64)  # 5 batches, 21 nodes per batch, each node a 64-dim vector
indices = torch.randint(low=0, high=21, size=(5, 21, 12))  # per batch and node: 12 indices of nodes in the same batch whose 64-dim vectors to select
out = gather(A, indices)  # desired op (placeholder name): out[b, n, k] = A[b, indices[b, n, k]], shape (5, 21, 12, 64)
```
How can I implement the `gather` operation above? Thanks.
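To pin down the intended semantics, here is a slow but unambiguous loop version. `gather_reference` is a hypothetical helper name, and it assumes `A` has shape `(B, N, D)` and `indices` has shape `(B, N, K)` with values in `[0, N)`. It is only useful as a correctness check for faster versions.

```python
import torch

def gather_reference(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Loop-based reference (hypothetical helper): out[b, n, k, :] = A[b, indices[b, n, k], :].

    A:       (B, N, D) node features per batch
    indices: (B, N, K) integer indices into the N nodes of the same batch
    returns: (B, N, K, D)
    """
    B, N, K = indices.shape
    D = A.shape[-1]
    out = A.new_empty((B, N, K, D))
    for b in range(B):
        for n in range(N):
            for k in range(K):
                # copy the selected node's D-dim vector from the same batch b
                out[b, n, k] = A[b, indices[b, n, k]]
    return out

A = torch.randn(5, 21, 64)
indices = torch.randint(low=0, high=21, size=(5, 21, 12))
out = gather_reference(A, indices)
print(out.shape)  # torch.Size([5, 21, 12, 64])
```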
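One way I found was to flatten the batch and node dimensions and index rows directly, shifting each batch's indices by its starting row (otherwise every index would point into the first batch):

```python
offsets = torch.arange(5).view(5, 1, 1) * 21  # first row of batch b in A.view(-1, 64) is b * 21
out = A.view(-1, 64)[(indices + offsets).view(-1, 12), :].view(5, 21, 12, 64)
```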
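The flatten-and-index idea can be written without hard-coded sizes. This is a sketch; `gather_flat` is a hypothetical name, and it assumes `A` is `(B, N, D)` and `indices` is `(B, N, K)`. Indexing the flattened `(B*N, D)` tensor with a `(B, N, K)` index tensor already returns `(B, N, K, D)`, so no final `view` is needed.

```python
import torch

def gather_flat(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, n, k, :] = A[b, indices[b, n, k], :] via one flat row lookup."""
    B, N, D = A.shape
    # Row b * N + i of the flattened A is node i of batch b, so shift batch b's
    # indices by b * N; without this shift every index would hit batch 0.
    offsets = torch.arange(B, device=indices.device).view(B, 1, 1) * N
    flat_rows = A.reshape(B * N, D)
    # A single (B, N, K) index tensor on dim 0 yields shape (B, N, K, D).
    return flat_rows[indices + offsets]

A = torch.randn(5, 21, 64)
indices = torch.randint(low=0, high=21, size=(5, 21, 12))
out = gather_flat(A, indices)
print(out.shape)  # torch.Size([5, 21, 12, 64])
```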
Any cleaner way of doing this?
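Two shorter alternatives, sketched under the same shape assumptions (`A` is `(B, N, D)`, `indices` is `(B, N, K)`); the function names are hypothetical. The first broadcasts a `(B, 1, 1)` batch index against `indices` in advanced indexing. The second uses `torch.gather` along the node dimension, which requires an index with the same rank as `A`, so `(N, K)` is flattened into one axis and each index is repeated across the `D` feature columns with `expand` (no extra memory).

```python
import torch

def gather_advanced(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # batch (B, 1, 1) broadcasts with indices (B, N, K), so element [b, n, k]
    # selects A[b, indices[b, n, k], :]; the trailing D dim is kept -> (B, N, K, D)
    B = A.shape[0]
    batch = torch.arange(B, device=A.device).view(B, 1, 1)
    return A[batch, indices]

def gather_torch(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # torch.gather on dim 1: out[b, m, d] = A[b, index[b, m, d], d]
    B, N, K = indices.shape
    D = A.shape[-1]
    index = indices.reshape(B, N * K, 1).expand(B, N * K, D)
    return torch.gather(A, 1, index).reshape(B, N, K, D)

A = torch.randn(5, 21, 64)
indices = torch.randint(low=0, high=21, size=(5, 21, 12))
out1 = gather_advanced(A, indices)
out2 = gather_torch(A, indices)
print(out1.shape, torch.equal(out1, out2))  # torch.Size([5, 21, 12, 64]) True
```

Both versions are differentiable with respect to `A`; when the same node is selected several times, its gradient contributions are summed.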