Selecting tensor subparts based on indices

Good morning,
The input of one of my layers is a tensor of points with shape [B,N,C]
and a tensor of ids to keep with shape [B,D0,…Dn].

What would be an efficient way to select the points corresponding to the ids?
My current implementation struggles with a high number of points:

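    import torch

    # points: [B, N, C], idx: [B, D0, ..., Dn] (torch.long)
    device = points.device
    B = points.shape[0]
    # view_shape = [B, 1, ..., 1], repeat_shape = [1, D0, ..., Dn]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    # batch_indices[b, ...] == b, same shape as idx
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
    # new_points: [B, D0, ..., Dn, C]
    new_points = points[batch_indices, idx, :]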
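The `repeat` call materializes a full copy of the batch indices with the same shape as `idx` before indexing. The following minimal sketch assumes `idx` is a `torch.long` tensor of shape [B, D0, …, Dn]. It lets advanced indexing broadcast a [B, 1, …, 1] batch index against `idx` instead. The function name `index_points_broadcast` is hypothetical:

```python
import torch


def index_points_broadcast(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Return points[b, idx[b, ...], :] for every batch element b.

    points: [B, N, C]; idx: [B, D0, ..., Dn] with values in [0, N).
    Returns a tensor of shape [B, D0, ..., Dn, C].
    """
    B = points.shape[0]
    # Shape [B, 1, ..., 1]: one singleton dim per non-batch dim of idx.
    batch_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, device=points.device).view(batch_shape)
    # Advanced indexing broadcasts batch_indices against idx, so no repeat is needed.
    return points[batch_indices, idx]


# Usage with hypothetical sizes.
B, N, C = 2, 6, 3
points = torch.randn(B, N, C)
idx = torch.randint(0, N, (B, 4, 5))
out = index_points_broadcast(points, idx)
print(out.shape)  # torch.Size([2, 4, 5, 3])
```

This avoids allocating the repeated index tensor. Whether it is measurably faster than the original depends on the sizes and the device.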

Also, the indices can have the shape [B,N,M].
For two-dimensional indices of shape [B,D0], I changed the code to:

    # One index_select per batch element, then stack: result is [B, D0, C]
    new_points = torch.stack([points[b].index_select(0, idx[b]) for b in range(idx.shape[0])], dim=0)
    return new_points

This improves the execution time, but I can’t figure out how to do the same for indices of shape [B,N,M], since index_select only accepts a one-dimensional index.
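One loop-free option that works for any index shape, including [B,N,M], is `torch.gather` along dimension 1. This requires flattening the index and expanding it over the channel dimension first. The sketch below assumes `idx` is `torch.long` with values in [0, N). The helper name `gather_points` is hypothetical:

```python
import torch


def gather_points(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select points along dim 1 for an index tensor of shape [B, D0, ..., Dn].

    points: [B, N, C]; idx: [B, D0, ..., Dn], values in [0, N).
    Returns [B, D0, ..., Dn, C] with out[b, d..., :] = points[b, idx[b, d...], :].
    """
    B, _, C = points.shape
    # Flatten all non-batch index dims: [B, K] with K = D0 * ... * Dn.
    flat_idx = idx.reshape(B, -1)
    # gather needs an index with as many dims as points: [B, K, C].
    flat_idx = flat_idx.unsqueeze(-1).expand(-1, -1, C)
    # out[b, k, c] = points[b, flat_idx[b, k, c], c]
    out = torch.gather(points, 1, flat_idx)
    # Restore the original index shape plus the channel dim.
    return out.reshape(*idx.shape, C)


# Usage with hypothetical sizes: indices of shape [B, N, M].
B, N, C, M = 2, 6, 3, 4
points = torch.randn(B, N, C)
idx = torch.randint(0, N, (B, N, M))
out = gather_points(points, idx)
print(out.shape)  # torch.Size([2, 6, 4, 3])
```

The `expand` call only creates a view, so the [B, K, C] index is not copied before `gather` reads it.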

I’m not sure how your indexing works, since I wasn’t able to run your code.
Based on the shapes, would something like this work?

    import torch

    B, N, C, M = 2, 3, 4, 5
    x = torch.randn(B, N, C)
    idx = torch.randint(0, C, (B, N, M))
    # out[b, n, m] = x[b, n, idx[b, n, m]], shape [B, N, M]
    out = torch.gather(x, 2, idx)
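To see what the suggested call computes, the following small sketch compares `torch.gather(x, 2, idx)` against explicit loops using the same shapes. The loops are only for illustration. With `dim=2`, each `idx[b, n, m]` picks one channel of point `x[b, n]`, so the output has shape [B, N, M]. Selecting whole points along `N` would instead require gathering along dim 1:

```python
import torch

torch.manual_seed(0)
B, N, C, M = 2, 3, 4, 5
x = torch.randn(B, N, C)
idx = torch.randint(0, C, (B, N, M))
out = torch.gather(x, 2, idx)  # shape [B, N, M]

# Explicit loops spelling out the same selection: the indices pick channels of each point.
ref = torch.empty(B, N, M)
for b in range(B):
    for n in range(N):
        for m in range(M):
            ref[b, n, m] = x[b, n, idx[b, n, m]]

print(out.shape)  # torch.Size([2, 3, 5])
print(torch.equal(out, ref))  # True
```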