How to find K-nearest neighbor of a tensor

It’s a bit late, but since I went through the same problem, here’s the answer:

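```python
import torch

def kNN(cloud, center, k):
    # Broadcast the (3,) center to the full shape of the cloud
    center = center.expand(cloud.shape)

    # Computing the Euclidean distance of every point to the center (last dim dropped)
    dist = cloud.add(-center).pow(2).sum(dim=3).pow(.5)

    # Getting the indices of the k nearest points along dim 2
    # (sorted=False: the k neighbors come back in no particular order)
    knn_indices = dist.topk(k, largest=False, sorted=False)[1]

    # Gathering the x-y-z coordinates of those k points
    return cloud.gather(2, knn_indices.unsqueeze(-1).repeat(1, 1, 1, 3))
```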

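where `cloud` is a 4-dimensional tensor whose third dimension indexes the points and whose last dimension (size 3) holds their x-y-z coordinates, and `center` is a tensor with the three coordinates x-y-z. The returned tensor has the same layout as `cloud`, with k points instead of all of them in the third dimension.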
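A minimal usage sketch, assuming a random cloud of shape (2, 4, 100, 3) (the sizes are arbitrary, chosen only for illustration). It repeats the function so it runs on its own and cross-checks it against a hypothetical variant, `knn_broadcast`, that relies on plain broadcasting plus `Tensor.norm`, and uses `expand` instead of `repeat` for the gather index so no copy is made.

```python
import torch

def kNN(cloud, center, k):
    center = center.expand(cloud.shape)
    dist = cloud.add(-center).pow(2).sum(dim=3).pow(.5)
    knn_indices = dist.topk(k, largest=False, sorted=False)[1]
    return cloud.gather(2, knn_indices.unsqueeze(-1).repeat(1, 1, 1, 3))

# Hypothetical variant (not from the original post): subtracting a (3,)
# tensor from a (D0, D1, N, 3) tensor broadcasts automatically.
def knn_broadcast(cloud, center, k):
    dist = (cloud - center).norm(dim=-1)                     # (D0, D1, N)
    idx = dist.topk(k, largest=False, sorted=True).indices   # (D0, D1, k)
    # expand() returns a view, whereas repeat() copies the data
    return cloud.gather(2, idx.unsqueeze(-1).expand(-1, -1, -1, 3))

torch.manual_seed(0)
cloud = torch.rand(2, 4, 100, 3)          # illustrative random points
center = torch.tensor([0.5, 0.5, 0.5])
k = 5

neighbors = kNN(cloud, center, k)
print(neighbors.shape)  # torch.Size([2, 4, 5, 3])

# kNN returns neighbors unsorted, so compare sorted distances instead of raw points
d1 = (neighbors - center).norm(dim=-1).sort(dim=-1).values
d2 = (knn_broadcast(cloud, center, k) - center).norm(dim=-1).sort(dim=-1).values
# Reference: the k smallest distances over all points
ref = (cloud - center).norm(dim=-1).sort(dim=-1).values[..., :k]
print(torch.allclose(d1, d2), torch.allclose(d1, ref))
```

Comparing sorted distances rather than the gathered points keeps the check independent of the order `topk` returns when `sorted=False`.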

The function is based on a reply by @ptrblck in the thread “K nearest neighbor in pytorch”.
