Let’s say I have a tensor x of size [B, N, C] and a tensor idx of size [B, N, D] (C > D). The tensor idx holds the indices along the last dimension that I need for cropping. For example:
x shape: torch.Size([2, 3, 3])
tensor([[[0.0734, 0.9977, 0.8373],
[0.3425, 0.1120, 0.7120],
[0.0164, 0.8370, 0.9897]],
[[0.5219, 0.7202, 0.0872],
[0.9454, 0.1704, 0.9376],
[0.8468, 0.5296, 0.0180]]])
idx shape: torch.Size([2, 3, 2])
tensor([[[1, 2],
[2, 1],
[1, 2]],
[[1, 2],
[1, 0],
[0, 2]]], dtype=torch.int32)
What I want:
tensor([[[0.9977, 0.8373],
[0.7120, 0.1120],
[0.8370, 0.9897]],
[[0.7202, 0.0872],
[0.1704, 0.9454],
[0.8468, 0.0180]]])
How can I do this in an efficient way?
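
For reference, the operation described above is out[b, n, d] = x[b, n, idx[b, n, d]], which is what torch.gather computes along the last dimension. Here is a minimal sketch using the example values from this post. The cast to int64 with .long() is an assumption on my part: torch.gather documents its index argument as a LongTensor, and the idx shown here is int32.

```python
import torch

# Example tensors copied from the post.
x = torch.tensor([[[0.0734, 0.9977, 0.8373],
                   [0.3425, 0.1120, 0.7120],
                   [0.0164, 0.8370, 0.9897]],
                  [[0.5219, 0.7202, 0.0872],
                   [0.9454, 0.1704, 0.9376],
                   [0.8468, 0.5296, 0.0180]]])
idx = torch.tensor([[[1, 2], [2, 1], [1, 2]],
                    [[1, 2], [1, 0], [0, 2]]], dtype=torch.int32)

# Pick, for every (b, n), the entries of x[b, n, :] listed in idx[b, n, :].
# idx is cast to int64 because gather documents its index as a LongTensor.
out = torch.gather(x, dim=2, index=idx.long())

print(out.shape)  # torch.Size([2, 3, 2])
print(out)
```

This runs as a single vectorized call with no Python loop over B or N, and the result has the same shape as idx.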