Approximate Nearest Neighbors layer

(Shai) #1

I have two collections of features (dim ~O(100), number of points ~O(10K)), ref and q.
I want to find k (approximate) nearest neighbors for each vector in q in the ref collection.
Is there a way to do this in PyTorch?
I am trying to find a CUDA implementation of k-NN search (kd-tree or LSH), but I can't find anything I can plug into a PyTorch layer.
Can anyone please point me in the right direction?
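At the sizes mentioned (~10K points, ~100 dims), exact brute-force search may already be fast enough on a GPU. This is a swap-in, not the kd-tree or LSH approach asked about: it computes all pairwise Euclidean distances with `torch.cdist` and keeps the k smallest per query with `torch.topk`. The function name `knn` and the `chunk_size` parameter are hypothetical. Queries are processed in chunks so that only a (chunk x n_ref) distance matrix is in memory at a time.

```python
import torch

def knn(ref: torch.Tensor, q: torch.Tensor, k: int, chunk_size: int = 1024):
    """Hypothetical helper: exact (brute-force) k-NN under Euclidean distance.

    ref: (n_ref, d), q: (n_q, d). Returns (dists, idx), each of shape (n_q, k),
    sorted from nearest to farthest. idx indexes rows of ref.
    """
    all_dists, all_idx = [], []
    # Chunk the queries so the pairwise distance matrix stays small.
    for start in range(0, q.shape[0], chunk_size):
        q_chunk = q[start:start + chunk_size]
        d = torch.cdist(q_chunk, ref)  # (chunk, n_ref) pairwise L2 distances
        # k smallest distances per query row, returned in ascending order
        vals, idx = torch.topk(d, k, dim=1, largest=False)
        all_dists.append(vals)
        all_idx.append(idx)
    return torch.cat(all_dists), torch.cat(all_idx)

# Example sizes roughly matching the post; random data for illustration only.
device = "cuda" if torch.cuda.is_available() else "cpu"
ref = torch.randn(10_000, 128, device=device)
q = torch.randn(10_000, 128, device=device)
dists, idx = knn(ref, q, k=5)
print(idx.shape)  # torch.Size([10000, 5])
```

Because `torch.topk` returns values that remain differentiable with respect to `q` and `ref`, the returned distances can be used inside a layer. The indices themselves are integer and carry no gradient. If exact search turns out too slow at larger scales, that is where a true approximate index would be worth the extra complexity.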