I have a tensor, say:
import torch
a = torch.rand(10, 2)
I would like to create a kNN graph of this tensor a using torch, such that it returns the k nearest-neighbour indices and distances for each row of a.
Basically, I am looking to do something like this (section "Finding the Nearest Neighbors") using torch, but it is not obvious to me how to do it.
One way I have arrived at is:
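# pairwise Euclidean distances, each row sorted in ascending order
dist = torch.sort(torch.cdist(a, a), dim=1)
distances = dist.values[:, 0:k]  # column 0 is each point itself (distance 0)
indices = dist.indices[:, 0:k]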
This gives me what I want, but are there more efficient ways to achieve it?
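For comparison, a minimal sketch of the same brute-force approach using torch.topk, which selects only the k smallest entries per row instead of fully sorting every row (typically cheaper when k is much smaller than N). It also masks the diagonal so a point is never returned as its own neighbour. It still builds the full N×N distance matrix, so it does not reduce peak memory. knn_topk is a hypothetical helper name.

```python
import torch

def knn_topk(x: torch.Tensor, k: int):
    """Return (distances, indices) of the k nearest neighbours of each row of x,
    excluding the row itself."""
    n = x.shape[0]
    d = torch.cdist(x, x)  # (N, N) pairwise Euclidean distances
    # Out-of-place mask of the diagonal so each point is not its own neighbour
    self_mask = torch.eye(n, dtype=torch.bool, device=x.device)
    d = d.masked_fill(self_mask, float("inf"))
    # k smallest distances per row, returned in ascending order
    distances, indices = torch.topk(d, k, dim=1, largest=False)
    return distances, indices

torch.manual_seed(0)
a = torch.rand(10, 2)
distances, indices = knn_topk(a, k=3)
print(distances.shape, indices.shape)  # torch.Size([10, 3]) torch.Size([10, 3])
```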
PS: This method is very memory-inefficient, because torch.cdist(a, a) builds the full N×N distance matrix. Is a kd-tree computation possible using torch?
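As far as I know, PyTorch itself does not ship a kd-tree. One way to cut memory while staying in plain torch is a chunked brute-force search (swapped in here instead of a kd-tree): compute distances for a block of query rows at a time and keep only the k smallest per row, so peak memory is roughly chunk_size × N distances rather than N × N. The helper knn_chunked and its chunk_size parameter are hypothetical; the default of 1024 is arbitrary and should be tuned to your memory budget.

```python
import torch

def knn_chunked(x: torch.Tensor, k: int, chunk_size: int = 1024):
    """Brute-force kNN over chunks of query rows; self-matches are excluded.
    Peak memory is about chunk_size * N distances instead of N * N."""
    n = x.shape[0]
    all_dist, all_idx = [], []
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        d = torch.cdist(x[start:end], x)  # (end - start, N) distances
        # Query row r in this chunk is point (start + r); mask its own column
        rows = torch.arange(end - start, device=x.device)
        self_mask = torch.zeros_like(d, dtype=torch.bool)
        self_mask[rows, rows + start] = True
        d = d.masked_fill(self_mask, float("inf"))
        # Keep only the k nearest per query row, then drop the chunk's distance matrix
        dist, idx = torch.topk(d, k, dim=1, largest=False)
        all_dist.append(dist)
        all_idx.append(idx)
    return torch.cat(all_dist), torch.cat(all_idx)

torch.manual_seed(0)
a = torch.rand(10, 2)
distances, indices = knn_chunked(a, k=3, chunk_size=4)
```

The chunk size trades memory for the number of cdist calls; the total work is still O(N²), unlike a tree-based search.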
Hi amitoz, I think torch_cluster has a function you can call directly to compute the kNN graph of a given tensor.
from torch_cluster import knn_graph
graph = knn_graph(a, k, loop=False)  # edge_index tensor of shape [2, num_edges]
Set loop=True if you want each node to include itself (a self-loop) in the graph.
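If you want a graph in the same [2, num_edges] edge-list shape without installing torch_cluster, a plain-PyTorch sketch could look like the following. knn_edge_index is a hypothetical helper; it puts the neighbour in row 0 and the query node in row 1, which I believe corresponds to knn_graph's default flow="source_to_target", but check the torch_cluster documentation before relying on the row order. With loop=True the point itself is counted among its k neighbours.

```python
import torch

def knn_edge_index(x: torch.Tensor, k: int, loop: bool = False) -> torch.Tensor:
    """Build a [2, N*k] edge list: row 0 = neighbour (source), row 1 = query node (target)."""
    n = x.shape[0]
    d = torch.cdist(x, x)  # (N, N) pairwise distances
    if not loop:
        # Exclude self-matches by giving the diagonal infinite distance
        d = d.masked_fill(torch.eye(n, dtype=torch.bool, device=x.device), float("inf"))
    _, nbr = torch.topk(d, k, dim=1, largest=False)  # (N, k) neighbour indices
    target = torch.arange(n, device=x.device).repeat_interleave(k)  # node i repeated k times
    source = nbr.reshape(-1)  # row-major, so it lines up with target
    return torch.stack([source, target], dim=0)

torch.manual_seed(0)
a = torch.rand(10, 2)
edge_index = knn_edge_index(a, k=3)
print(edge_index.shape)  # torch.Size([2, 30])
```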