Indexing 3D Tensor using 3D Index

I have the same question.
To clarify, assume I have a tensor A with A.shape = 2 x 3 x 2 (B x N x C):

A = torch.tensor(
    [[[ 0.4930,  0.1150],
      [-0.2355,  1.1917],
      [-1.2421, -0.4383]],
     [[ 0.3099,  3.4751],
      [ 0.7780,  1.0990],
      [-0.0795,  0.1633]]])

and an index tensor idx with idx.shape = 2 x 3 x 1 (B x N x K):

idx = torch.tensor([[[0],
                     [1],
                     [2]],
                    [[1],
                     [2],
                     [1]]])

I would like to get a 4D tensor res with res.shape = 2 x 3 x 1 x 2 (B x N x K x C):

res = tensor(
    [[[[ 0.4930,  0.1150]],
      [[-0.2355,  1.1917]],
      [[-1.2421, -0.4383]]],
     [[[ 0.7780,  1.0990]],
      [[-0.0795,  0.1633]],
      [[ 0.7780,  1.0990]]]])
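Element-wise, the intended result is res[b, n, k, :] = A[b, idx[b, n, k], :]. Below is a slow but explicit sketch of that definition, meant only as a reference for checking faster versions (the function name reference_gather is hypothetical):

```python
import torch

def reference_gather(A: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Explicit-loop definition: res[b, n, k, :] = A[b, idx[b, n, k], :]."""
    B, N, C = A.shape
    K = idx.shape[-1]
    res = A.new_empty((B, N, K, C))
    for b in range(B):
        for n in range(N):
            for k in range(K):
                # Copy the row of A[b] selected by idx[b, n, k].
                res[b, n, k] = A[b, int(idx[b, n, k])]
    return res
```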

My current implementation is:

res = torch.cat([A[i, idx[i]] for i in range(B)], dim=0).view(B, N, K, C)
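A small simplification that still keeps the loop: torch.stack adds the batch dimension itself, so the cat + view step is not needed. A sketch, assuming idx holds int64 indices (the function name loop_stack is hypothetical):

```python
import torch

def loop_stack(A: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Per-batch advanced indexing, stacked along a new leading dim."""
    # A[i] has shape (N, C); indexing it with idx[i] of shape (N, K)
    # gives (N, K, C), and stacking B of these gives (B, N, K, C).
    return torch.stack([A[i][idx[i]] for i in range(A.shape[0])], dim=0)
```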

Is there a more efficient way to do this (e.g., one that avoids the for loop)?
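One loop-free candidate is to broadcast a batch index against idx using PyTorch advanced indexing: with b_idx of shape (B, 1, 1), A[b_idx, idx] broadcasts the two index tensors to (B, N, K) and picks a length-C row for each entry. torch.gather along dim 1 is an alternative. A sketch of both, assuming idx is an int64 tensor on the same device as A (the function names are hypothetical); whether either beats the loop in practice depends on the sizes and device:

```python
import torch

def index_batched(A: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """res[b, n, k, :] = A[b, idx[b, n, k], :] via advanced indexing."""
    B = A.shape[0]
    # Batch index of shape (B, 1, 1) broadcasts against idx (B, N, K).
    b_idx = torch.arange(B, device=A.device).view(B, 1, 1)
    # Index tensors on dims 0 and 1 -> result shape (B, N, K) + (C,).
    return A[b_idx, idx]

def gather_batched(A: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Same result with torch.gather along dim 1."""
    B, N, C = A.shape
    K = idx.shape[-1]
    # Flatten (N, K) and repeat each index across the C channels.
    flat = idx.reshape(B, N * K, 1).expand(B, N * K, C)
    # out[b, j, c] = A[b, flat[b, j, c], c], shape (B, N * K, C).
    return torch.gather(A, 1, flat).view(B, N, K, C)
```

Both versions work unchanged for K > 1, since K only enters through the shape of idx.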
Thanks in advance!