K-nearest neighbors in PyTorch

After calculating the distance between your test sample and your data samples, you could probably use `topk` with `largest=False` to get the nearest neighbors.
Would this work for you:

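```python
import torch

data = torch.randn(100, 10)  # 100 samples with 10 features each
test = torch.randn(1, 10)    # a single query sample

# Euclidean (L2) distance from the query to every sample; test broadcasts to (100, 10)
dist = torch.norm(data - test, dim=1, p=2)
# The 3 smallest distances (and their indices) are the 3 nearest neighbors
knn = dist.topk(3, largest=False)

print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))
```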
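If you have several test samples, a minimal sketch (under some assumptions) would compute all pairwise distances at once with `torch.cdist` and call `topk` along the last dimension. To show how the neighbor indices could be used for classification, it also assumes a hypothetical `labels` tensor holding one class per data sample and takes a simple majority vote with `torch.mode`. The labels, `k = 3`, and the number of queries are illustrative choices, not part of the answer above.

```python
import torch

torch.manual_seed(0)

data = torch.randn(100, 10)           # 100 reference samples, 10 features each
labels = torch.randint(0, 3, (100,))  # hypothetical class labels (assumption, for illustration)
test = torch.randn(5, 10)             # 5 query samples

k = 3
# Pairwise Euclidean distances: dist[i, j] = ||test[i] - data[j]||_2, shape (5, 100)
dist = torch.cdist(test, data, p=2.0)
# For each query (row), the k smallest distances and their column indices
knn = dist.topk(k, dim=1, largest=False)

# Labels of the k neighbors of every query, shape (5, k)
neighbor_labels = labels[knn.indices]
# Majority vote per query (ties are not broken by distance here)
pred = neighbor_labels.mode(dim=1).values

print('kNN indices:\n{}'.format(knn.indices))
print('predicted labels: {}'.format(pred))
```

This avoids a Python loop over the queries. For a very large dataset the full `(num_queries, num_samples)` distance matrix may not fit in memory, in which case you would probably have to process the queries in chunks.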