Approximate Nearest Neighbors layer

I have two collections of features (dim ~O(100), number of points ~O(10K)), ref and q.
I want to find k (approximate) nearest neighbors for each vector in q in the ref collection.
Is there a way to do it in pytorch?
I am trying to find a CUDA implementation of k-NN search (kd-tree or LSH), but I can't find anything I can plug into a pytorch layer.
Can anyone please point me in the right direction?

Thanks!
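For ~10K x 10K points of dimension ~100, exact brute-force search is also feasible with plain PyTorch (on GPU or CPU), which may be enough as a baseline. Below is a minimal sketch; note it is exact search, not a kd-tree or LSH. `knn_bruteforce` and its `chunk` parameter are hypothetical names for illustration, and chunking the queries is just one way to bound the size of the distance matrix held in memory.

```python
import torch

def knn_bruteforce(ref: torch.Tensor, q: torch.Tensor, k: int, chunk: int = 1024):
    """Exact k-NN (hypothetical helper): for each row of q, the k closest rows of ref
    under Euclidean distance. Returns (dists, idxs), both of shape (len(q), k),
    sorted from nearest to farthest."""
    dists, idxs = [], []
    # Process queries in chunks so only a (chunk, len(ref)) distance matrix exists at once
    for start in range(0, q.shape[0], chunk):
        d = torch.cdist(q[start:start + chunk], ref)   # pairwise L2 distances
        dk, ik = d.topk(k, dim=1, largest=False)      # k smallest per row, ascending
        dists.append(dk)
        idxs.append(ik)
    return torch.cat(dists), torch.cat(idxs)

# Usage with the sizes from the question (random data for illustration)
device = "cuda" if torch.cuda.is_available() else "cpu"
ref = torch.randn(10_000, 100, device=device)
q = torch.randn(10_000, 100, device=device)
dists, idxs = knn_bruteforce(ref, q, k=5)
print(idxs.shape)  # torch.Size([10000, 5])
```

Since `torch.cdist` and `topk` both support autograd, gradients can flow back through the returned distances if this sits inside a layer; the indices themselves are not differentiable.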

For now, I have found the faiss package to be useful.
Here’s an example of using faiss with pytorch.

It’s not perfect (takes a lot of GPU memory, crashes every once in a while) but it does the job for me.