It's a bit late, but since I went through the same problem, here's the answer:
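```python
import torch

def kNN(cloud, center, k):
    # Broadcast the query point to the full shape of the cloud
    center = center.expand(cloud.shape)
    # Euclidean distance of every point to center, summed over x-y-z (dim 3)
    dist = (cloud - center).pow(2).sum(dim=3).sqrt()
    # Indices of the k nearest points along the point dimension (not sorted)
    knn_indices = dist.topk(k, largest=False, sorted=False)[1]
    # Gather the x-y-z coordinates of those k points along dim 2
    return cloud.gather(2, knn_indices.unsqueeze(-1).repeat(1, 1, 1, 3))
```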
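where cloud is a 4-dimensional tensor whose points are indexed along dimension 2 and whose last dimension holds the x-y-z coordinates, center is a tensor with the three coordinates x-y-z, and k is the number of neighbours to return. Because topk is called with sorted=False, the k neighbours come back in no particular order.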
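A quick sanity check, as a sketch: the toy cloud below (five hypothetical points on the x axis, shaped (1, 1, 5, 3)) is made up for illustration. It also shows that center can be either a plain 3-vector or a tensor with size-1 leading dimensions, since expand broadcasts both to the cloud's shape.

```python
import torch

def kNN(cloud, center, k):
    # Same function as above: broadcast center, compute Euclidean distances
    # over x-y-z (dim 3), keep the k smallest, gather them along dim 2.
    center = center.expand(cloud.shape)
    dist = (cloud - center).pow(2).sum(dim=3).sqrt()
    knn_indices = dist.topk(k, largest=False, sorted=False)[1]
    return cloud.gather(2, knn_indices.unsqueeze(-1).repeat(1, 1, 1, 3))

# Hypothetical toy cloud: points at x = 0, 1, 2, 3, 4 (y = z = 0),
# shaped (1, 1, 5, 3) so the points live along dim 2.
xs = torch.arange(5, dtype=torch.float32)
cloud = torch.stack([xs, torch.zeros(5), torch.zeros(5)], dim=-1).view(1, 1, 5, 3)

# Query with a plain 3-vector at the origin
near = kNN(cloud, torch.tensor([0.0, 0.0, 0.0]), k=2)
print(near.shape)  # torch.Size([1, 1, 2, 3])

# Query with a (1, 1, 1, 3) center at x = 4
far = kNN(cloud, torch.tensor([[[[4.0, 0.0, 0.0]]]]), k=2)
print(sorted(far[0, 0, :, 0].tolist()))  # [3.0, 4.0]
```

Sorting the x coordinates before printing is deliberate: with sorted=False the order of the returned neighbours is not guaranteed.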
The function is based on a reply by @ptrblck in the thread "K nearest neighbor in pytorch".