How to find the k nearest neighbors in a tensor

Hi. I have a tensor of shape (Batch, 40, 128, 3). How can I find the k nearest neighbors of a given constant 3-dimensional data point, so that I get a tensor of shape (Batch, 40, k, 3)? Thanks

It’s a bit late, but since I went through the same problem, here’s the answer:

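import torch

def kNN(cloud, center, k):
    # cloud: (Batch, 40, 128, 3); center: (3,), expanded to every point
    center = center.expand(cloud.shape)

    # Euclidean distance from every point to center: (Batch, 40, 128)
    dist = (cloud - center).pow(2).sum(dim=3).sqrt()

    # Indices of the k nearest points: (Batch, 40, k)
    # sorted=False: the k neighbors are not ordered by distance
    knn_indices = dist.topk(k, dim=2, largest=False, sorted=False)[1]

    # Gather all 3 coordinates of each selected point: (Batch, 40, k, 3)
    return cloud.gather(2, knn_indices.unsqueeze(-1).repeat(1, 1, 1, 3))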

where cloud is a 4-dimensional tensor of shape (Batch, 40, 128, 3) and center is a tensor of shape (3,) holding the x, y, z coordinates.
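A variant sketch of the same idea (my own adaptation, not from the original reply): because the square root is monotonic, running topk on squared distances selects exactly the same points, so the sqrt can be skipped. Leaving topk at its default sorted=True returns the neighbors nearest-first, and expand replaces repeat so the index is not copied. The name knn_points and the random test data are illustrative assumptions.

```python
import torch


def knn_points(cloud, center, k):
    """Return the k points of `cloud` closest to `center`, nearest first.

    cloud:  (B, G, N, 3) tensor, e.g. (Batch, 40, 128, 3)
    center: (3,) tensor, broadcast against every point
    returns a (B, G, k, 3) tensor
    """
    # Squared distances suffice: sqrt is monotonic, so topk picks the same points
    sq_dist = (cloud - center).pow(2).sum(dim=-1)            # (B, G, N)
    # Default sorted=True orders the k neighbors from nearest to farthest
    idx = sq_dist.topk(k, dim=-1, largest=False).indices     # (B, G, k)
    # Broadcast each index over the xyz axis so gather returns whole points
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, cloud.size(-1))  # (B, G, k, 3)
    return cloud.gather(2, idx)


cloud = torch.randn(2, 40, 128, 3)
center = torch.tensor([0.0, 0.0, 0.0])
out = knn_points(cloud, center, k=8)
print(out.shape)  # torch.Size([2, 40, 8, 3])
```

If each (batch, group) pair had its own query point, center could instead be given shape (B, G, 1, 3) and the same broadcasting would apply.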

The function is based on a reply by @ptrblck in the thread "K nearest neighbor in pytorch".


Late as well, haha, but for anyone who might still be looking: I implemented NN, KNN and KMeans using only PyTorch for a project I am working on. You can find the implementation, with an example, here: Nearest Neighbor, K Nearest Neighbor and K Means (NN, KNN, KMeans) only using PyTorch · GitHub

>>> import torch as th
>>> from clustering import KNN
>>> data = th.Tensor([[1, 1], [0.88, 0.90], [-1, -1], [-1, -0.88]])
>>> labels = th.LongTensor([3, 3, 5, 5])
>>> test = th.Tensor([[-0.5, -0.5], [0.88, 0.88]])
>>> knn = KNN(data, labels)
>>> knn(test)
tensor([5, 3])
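For readers curious what a k-NN classifier like the one above does internally, here is a minimal sketch. It is not the gist's code: knn_classify and its k=3 default are hypothetical choices. It uses torch.cdist for pairwise distances, topk to pick the nearest training points, and torch.mode for the majority vote over their labels. On the same data it also predicts [5, 3].

```python
import torch


def knn_classify(train_x, train_y, test_x, k=3):
    """Hypothetical minimal k-NN classifier: majority vote over the k nearest training points."""
    # Pairwise Euclidean distances, shape (n_test, n_train)
    dist = torch.cdist(test_x, train_x)
    # Indices of the k closest training points per test point, shape (n_test, k)
    knn_idx = dist.topk(k, dim=1, largest=False).indices
    # Labels of those neighbors, shape (n_test, k)
    knn_labels = train_y[knn_idx]
    # Most frequent label in each row
    return torch.mode(knn_labels, dim=1).values


data = torch.tensor([[1.0, 1.0], [0.88, 0.90], [-1.0, -1.0], [-1.0, -0.88]])
labels = torch.tensor([3, 3, 5, 5])
test = torch.tensor([[-0.5, -0.5], [0.88, 0.88]])
print(knn_classify(data, labels, test))  # tensor([5, 3])
```

An odd k avoids ties in a two-class vote; with more classes, ties are still possible and torch.mode then picks one of the tied labels.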

This is an updated link for the project, which is now under a new username: Nearest Neighbor, K Nearest Neighbor and K Means (NN, KNN, KMeans) implemented only using PyTorch · GitHub