Gather with batch size

Hello, say I have a tensor c of 3D coordinates with shape (B x P x 3), where B is the batch size, and another tensor of indices idx with shape (B x I x 2):

import torch

c = torch.tensor([[[1, 1, 2], [5, 2, 3], [6, 1, 8]]])  # (B, P, 3)
idx = torch.tensor([[[0, 1], [1, 2]]])                  # (B, I, 2)

(Here the batch size is 1 for simplicity.)
I want to gather coordinates from c according to idx, so the result should be a tensor with shape (B x I x 2 x 3) and values

result[0, 0] = torch.tensor([[1, 1, 2], [5, 2, 3]])
result[0, 1] = torch.tensor([[5, 2, 3], [6, 1, 8]])
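In other words, the intended result is result[b, i, j] = c[b, idx[b, i, j]]. Here is a minimal sketch of one way to compute that with advanced indexing, assuming idx holds int64 indices into dim 1 of c. The names B and batch are introduced only for illustration. The idea is to pair every entry of idx with its own batch element by indexing dim 0 with a broadcastable batch index:

```python
import torch

# Example data from the question (B=1, P=3, I=2)
c = torch.tensor([[[1, 1, 2], [5, 2, 3], [6, 1, 8]]])  # (B, P, 3)
idx = torch.tensor([[[0, 1], [1, 2]]])                  # (B, I, 2)

B = c.shape[0]
# A batch index of shape (B, 1, 1) broadcasts against idx (B, I, 2),
# so result[b, i, j] = c[b, idx[b, i, j], :] for every b.
batch = torch.arange(B).view(B, 1, 1)
result = c[batch, idx]  # (B, I, 2, 3)

print(result.shape)  # torch.Size([1, 2, 2, 3])
```

Because the batch index broadcasts, the same line works unchanged for any B.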

I have tried indexing with c[:, idx], but that only seems to work when B=1, and even then it adds dimensions I don't want. I also tried torch.gather(c, dim=1, index=idx), without success. What is the right way to do this?
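Both attempts can be explained by looking at shapes. Below is a sketch that assumes the same c and idx; K (here 2) and flat are illustrative names, not from the original post. c[:, idx] indexes dim 1 of every batch element with the indices of all batch elements. torch.gather, on the other hand, needs an index with the same number of dimensions as c, and its output takes the index's shape. It therefore only works after idx is flattened and expanded over the coordinate dimension:

```python
import torch

c = torch.tensor([[[1, 1, 2], [5, 2, 3], [6, 1, 8]]])  # (B, P, 3)
idx = torch.tensor([[[0, 1], [1, 2]]])                  # (B, I, K) with K = 2
B, I, K = idx.shape
D = c.shape[-1]  # 3 coordinates

# Plain indexing does not pair batch elements: the shape is (B, B, I, K, D),
# which is the extra dimension seen even when B = 1.
print(c[:, idx].shape)  # torch.Size([1, 1, 2, 2, 3])

# gather along dim 1: out[b, n, d] = c[b, flat[b, n, d], d].
# Flatten idx to (B, I*K, 1) and expand it over the D coordinates.
flat = idx.reshape(B, I * K, 1).expand(-1, -1, D)
result = torch.gather(c, dim=1, index=flat).reshape(B, I, K, D)

print(result.shape)  # torch.Size([1, 2, 2, 3])
```

The gather version gives the same values as indexing with a broadcast batch index. The expand call creates a view, so no copy of the index is made.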

Thanks!
