Need help gathering vectors from a 3D tensor based on indices given in another 3D tensor

import torch

A = torch.randn(5, 21, 64)  # 5 batches, 21 nodes per batch, each node represented by a 64-dim vector
indices = torch.randint(low=0, high=21, size=(5, 21, 12))  # for each batch and node, 12 node indices (0..20) selecting 64-dim vectors from the same batch of A

out = gather(A, indices)  # pseudocode; desired: out[b, n, k] = A[b, indices[b, n, k]], output size (5, 21, 12, 64)
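To pin down the intended semantics, here is a deliberately slow reference sketch using explicit loops. It assumes out[b, n, k] should be the 64-dim vector of node indices[b, n, k] taken from the same batch b. The name `gather_reference` is hypothetical, not a PyTorch API; it is only useful for checking faster versions.

```python
import torch

def gather_reference(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, n, k, :] = A[b, indices[b, n, k], :]."""
    B, N, D = A.shape
    K = indices.shape[-1]
    out = A.new_empty((B, N, K, D))
    for b in range(B):
        for n in range(N):
            for k in range(K):
                # copy the D-dim vector of the selected node, from the same batch b
                out[b, n, k] = A[b, indices[b, n, k]]
    return out

A = torch.randn(5, 21, 64)
indices = torch.randint(low=0, high=21, size=(5, 21, 12))
out = gather_reference(A, indices)
print(out.shape)  # torch.Size([5, 21, 12, 64])
```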

Please help me implement the above functionality. Thanks

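One way I found was:

A.view(-1, 64)[(indices + 21 * torch.arange(5).view(5, 1, 1)).view(-1, 12), :].view(5, 21, 12, 64)  # shift batch b's indices by b * 21 so they address rows of the flattened (105, 64) A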
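A note on the flattening approach: `indices` holds node positions 0..20 within each batch, but `A.view(-1, 64)` stacks all 5 batches into 105 rows, so batch b's nodes live at rows b*21 to b*21+20. Without a per-batch offset, every lookup lands in batch 0. A minimal sketch, assuming the shapes from the question (the names B, N, K, D are introduced here only for readability):

```python
import torch

B, N, K, D = 5, 21, 12, 64
A = torch.randn(B, N, D)
indices = torch.randint(low=0, high=N, size=(B, N, K))

# Offset for batch b is b * N; shape (B, 1, 1) broadcasts over (N, K).
offsets = torch.arange(B).view(B, 1, 1) * N
flat_idx = (indices + offsets).reshape(-1)          # shape (B*N*K,)
out = A.reshape(-1, D)[flat_idx].view(B, N, K, D)   # shape (B, N, K, D)

# Without the offset, every batch reads rows 0..N-1, i.e. batch 0's nodes.
wrong = A.reshape(-1, D)[indices.reshape(-1)].view(B, N, K, D)
print(torch.equal(out[0], wrong[0]))  # True: batch 0 has offset 0
```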

Is there a cleaner way of doing this?

Are you expecting the output size to be [5, 21, 12, 64] or rather [5, 21, 12]?
In the latter case, you can use gather:

out = A.gather(2, indices)
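For reference, `torch.gather` indexes along the dim it is given, so `A.gather(2, indices)` computes out[b, n, k] = A[b, n, indices[b, n, k]]: it picks single scalars from each node's own 64 features rather than whole vectors of other nodes. A small check, assuming the shapes from the question:

```python
import torch

A = torch.randn(5, 21, 64)
indices = torch.randint(low=0, high=21, size=(5, 21, 12))

out = A.gather(2, indices)  # gather along the feature dim (size 64)
print(out.shape)            # torch.Size([5, 21, 12])

# Each entry is one feature of node n itself, not the vector of node indices[b, n, k].
b, n, k = 1, 2, 3
assert torch.equal(out[b, n, k], A[b, n, indices[b, n, k]])
```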

I need it to be [5, 21, 12, 64]: I want to select the 64-dim vectors of the nodes given by the 12 indices.
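Two vectorized ways to get the [5, 21, 12, 64] result, sketched under the assumption that out[b, n, k] = A[b, indices[b, n, k]]. The first uses advanced indexing with a broadcast batch index; the second uses `torch.gather` along the node dim, which needs an index with the same number of dims as `A`, so each node index is repeated across the 64 features with `expand` (no copy).

```python
import torch

B, N, K, D = 5, 21, 12, 64
A = torch.randn(B, N, D)
indices = torch.randint(low=0, high=N, size=(B, N, K))

# Option 1: advanced indexing. batch_idx (B, 1, 1) broadcasts against
# indices (B, N, K); the untouched last dim D is appended to the result.
batch_idx = torch.arange(B).view(B, 1, 1)
out1 = A[batch_idx, indices]                          # shape (B, N, K, D)

# Option 2: gather along dim 1 (nodes). Flatten (N, K) and repeat each
# index D times so that out2[b, m, d] = A[b, idx[b, m, d], d].
idx = indices.reshape(B, N * K, 1).expand(-1, -1, D)  # shape (B, N*K, D)
out2 = A.gather(1, idx).view(B, N, K, D)

print(out1.shape)               # torch.Size([5, 21, 12, 64])
print(torch.equal(out1, out2))  # True
```

Option 1 is usually the more readable of the two; option 2 may be handy when the code already works in terms of `gather`.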