Good morning,
The input to my layer is a tensor of points with shape [B, N, C]
and a tensor of indices to keep with shape [B, D0, …, Dn].
What is an efficient way to select the points corresponding to these indices?
My current implementation becomes slow when the number of points is high:
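    import torch

    device = points.device
    B = points.shape[0]
    # view_shape = [B, 1, ..., 1]: one batch index per element of dim 0
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    # repeat_shape = [1, D0, ..., Dn]: copy the batch index to idx's shape
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
    # Result has shape [B, D0, ..., Dn, C]
    new_points = points[batch_indices, idx, :]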
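A minimal sketch of the same selection without the `.repeat(...)` call, assuming `idx` is a `torch.long` tensor on the same device as `points`. Advanced indexing broadcasts index tensors against each other, so a batch index of shape [B, 1, …, 1] is enough and no repeated copy of it is allocated. The helper name `index_points_broadcast` is hypothetical, not from the original code.

```python
import torch


def index_points_broadcast(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select points[b, idx[b, ...], :] for every batch element b.

    points: [B, N, C]; idx: [B, D0, ..., Dn] (long).
    Returns [B, D0, ..., Dn, C]. Hypothetical helper name; sketch only.
    """
    B = points.shape[0]
    # Batch index of shape [B, 1, ..., 1]; advanced indexing broadcasts it
    # against idx, so no .repeat() copy is materialized.
    batch = torch.arange(B, device=points.device).view(B, *([1] * (idx.dim() - 1)))
    return points[batch, idx]


if __name__ == "__main__":
    points = torch.randn(2, 5, 3)
    idx = torch.randint(0, 5, (2, 4, 6))  # e.g. an index tensor of shape [B, N, M]
    out = index_points_broadcast(points, idx)
    print(out.shape)  # torch.Size([2, 4, 6, 3])
```

Because the broadcast works for any number of trailing index dimensions, the same call covers both the [B, D0, …, Dn] and the [B, N, M] cases.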
The indices can also have shape [B, N, M].
For indices of shape [B, D0, …, Dn], I changed the code to:
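    # index_select needs a 1-D index, so this assumes idx has shape [B, D].
    # points[b:b + 1] keeps only batch element b, giving [1, D, C] per step.
    new_points = torch.cat([points[b:b + 1].index_select(1, idx[b]) for b in range(idx.shape[0])], dim=0)
    return new_points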
This improved the execution time, but I can't figure out how to do the same for indices of shape [B, N, M].
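One way to extend the `index_select` idea to [B, N, M] (or any index shape) without a Python loop is to flatten the batch: view `points` as [B * N, C], shift each batch's indices by `b * N`, and call `index_select` once. This is a sketch of a common offset-and-flatten technique, not the original poster's code; the helper name `index_points_flat` is hypothetical, and `idx` is assumed to be a `torch.long` tensor with values in [0, N).

```python
import torch


def index_points_flat(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select points with a single index_select call for any index shape.

    points: [B, N, C]; idx: [B, ...] (long), e.g. [B, N, M].
    Returns [B, ..., C]. Hypothetical helper name; sketch only.
    """
    B, N, C = points.shape
    # Offset each batch's indices by b * N so they address rows of the
    # flattened [B * N, C] tensor.
    offsets = torch.arange(B, device=points.device).view(B, *([1] * (idx.dim() - 1))) * N
    flat_idx = (idx + offsets).reshape(-1)
    # One selection over all batches, then restore the index shape plus C.
    selected = points.reshape(B * N, C).index_select(0, flat_idx)
    return selected.view(*idx.shape, C)


if __name__ == "__main__":
    points = torch.randn(2, 5, 3)
    idx = torch.randint(0, 5, (2, 5, 4))  # [B, N, M]
    print(index_points_flat(points, idx).shape)  # torch.Size([2, 5, 4, 3])
```

`torch.gather` on an expanded index is another loop-free option; which variant is fastest likely depends on tensor sizes and device, so timing them on real data is worthwhile.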