Cropping a tensor using a tensor of indices

Let’s say I have a tensor x of size [B, N, C] and a tensor idx of size [B, N, D], with C > D. Tensor idx holds the indices (along the last dimension of x) that I need for cropping. For example:

```
x shape: torch.Size([2, 3, 3])
tensor([[[0.0734, 0.9977, 0.8373],
         [0.3425, 0.1120, 0.7120],
         [0.0164, 0.8370, 0.9897]],

        [[0.5219, 0.7202, 0.0872],
         [0.9454, 0.1704, 0.9376],
         [0.8468, 0.5296, 0.0180]]])

idx shape: torch.Size([2, 3, 2])
tensor([[[1, 2],
         [2, 1],
         [1, 2]],

        [[1, 2],
         [1, 0],
         [0, 2]]], dtype=torch.int32)
```

What I want:

```
tensor([[[0.9977, 0.8373],
         [0.7120, 0.1120],
         [0.8370, 0.9897]],

        [[0.7202, 0.0872],
         [0.1704, 0.9454],
         [0.8468, 0.0180]]])
```
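Written out element-wise, the desired output picks, for every batch b and row n, the entries of that row of x listed in idx:

$$
\text{out}[b, n, d] = x\big[b,\ n,\ \text{idx}[b, n, d]\big], \qquad 0 \le b < B,\ 0 \le n < N,\ 0 \le d < D
$$

For example, idx[1, 1] = [1, 0], so out[1, 1] = [x[1, 1, 1], x[1, 1, 0]] = [0.1704, 0.9454], which matches the second row of the second batch above.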

How can I do this in an efficient way?

I think I have found a way. I am posting it here in case somebody needs it.

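```
import torch
B, N, D = idx.shape
idx = idx.long()  # int64 indices work across PyTorch versions
# batch / row index of every entry in idx, both expanded to [B, N, D]
bidx = torch.arange(B, device=x.device).view(B, 1, 1).expand_as(idx).contiguous()
nidx = torch.arange(N, device=x.device).view(1, N, 1).expand_as(idx).contiguous()
# out[b, n, d] = x[b, n, idx[b, n, d]]
x = x[bidx.view(-1), nidx.view(-1), idx.view(-1)].view(B, N, D)
```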
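For what it's worth, the same selection is exactly what torch.gather does along the last dimension, so a shorter alternative (not the approach above, just a sketch) might be `torch.gather(x, 2, idx.long())`. The snippet below also shows that the advanced-indexing version can rely on broadcasting, which avoids the expand_as(...).contiguous() and view(-1) steps. It reuses the example tensors from the question; the names out_gather and out_index are only illustrative.

```python
import torch

# Example data from the question
x = torch.tensor([[[0.0734, 0.9977, 0.8373],
                   [0.3425, 0.1120, 0.7120],
                   [0.0164, 0.8370, 0.9897]],
                  [[0.5219, 0.7202, 0.0872],
                   [0.9454, 0.1704, 0.9376],
                   [0.8468, 0.5296, 0.0180]]])
idx = torch.tensor([[[1, 2], [2, 1], [1, 2]],
                    [[1, 2], [1, 0], [0, 2]]], dtype=torch.int32)

# torch.gather needs an int64 index with the same number of dims as x.
# Along dim=2 it computes out[b, n, d] = x[b, n, idx[b, n, d]].
out_gather = torch.gather(x, 2, idx.long())

# Equivalent advanced indexing: [B,1,1], [1,N,1] and [B,N,D] index tensors
# broadcast together to [B,N,D], so no expand/contiguous/view is needed.
B, N, _ = x.shape
bidx = torch.arange(B, device=x.device).view(B, 1, 1)
nidx = torch.arange(N, device=x.device).view(1, N, 1)
out_index = x[bidx, nidx, idx.long()]

print(out_gather)
print(torch.equal(out_gather, out_index))
```

Both results have shape [B, N, D]. gather is usually the more readable choice when the indices only vary along a single dimension.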