# How to compute the k-NN graph of a tensor using PyTorch?

I have a tensor, say:

```python
import torch

a = torch.rand(10, 2)
```

I would like to build a k-NN graph of this tensor `a` using torch, returning `k` indices and distances for each row of `a`. That is, `distances.shape` is `[10, k]` and `indices.shape` is `[10, k]`.

Basically, I am looking to do something like this (section 1.6.1.1. Finding the Nearest Neighbors) using torch, but it is not evident to me how to do it.

One way I have arrived at is:

```python
# Sort each row of the full pairwise distance matrix in ascending order
dist = torch.sort(torch.cdist(a, a), dim=1)
```

Then, assuming `k = 3`:

```python
k = 3
# Column 0 is normally each point itself, at distance 0
distances = dist.values[:, 0:k]
indices = dist.indices[:, 0:k]
```
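
A possible refinement of the sort-based approach, as a minimal sketch: `torch.topk` with `largest=False` selects only the `k` smallest distances per row instead of fully sorting each row. The helper name `knn_topk` and its `include_self` flag are hypothetical, introduced here for illustration. The sketch still builds the full `[N, N]` distance matrix, so it does not reduce memory use.

```python
import torch

def knn_topk(a, k, include_self=False):
    """Brute-force k-NN (hypothetical helper). Returns (distances, indices), each [N, k]."""
    dist = torch.cdist(a, a)  # full [N, N] pairwise Euclidean distances
    if not include_self:
        # Mask the diagonal so a point is never returned as its own neighbor.
        dist.fill_diagonal_(float("inf"))
    # largest=False picks the k smallest entries per row, in ascending order.
    distances, indices = torch.topk(dist, k, dim=1, largest=False)
    return distances, indices

a = torch.rand(10, 2)
distances, indices = knn_topk(a, k=3)
print(distances.shape, indices.shape)
```

Masking the diagonal means every returned neighbor is a different point, which the plain sort does not guarantee because the first column is usually the point itself.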

This gives me what I wanted, but perhaps there is a more efficient way to achieve it?
PS: This method is highly memory-inefficient, since `torch.cdist(a, a)` builds the full pairwise distance matrix. Is a `kd-tree` computation possible using torch?
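
On the memory concern, one common workaround is to process the query rows in blocks, so that only a `[chunk_size, N]` slice of the distance matrix exists at any time. The sketch below is still brute force, not a kd-tree. The function name `knn_chunked` and the `chunk_size` parameter are hypothetical names chosen for illustration.

```python
import torch

def knn_chunked(a, k, chunk_size=1024):
    """Blockwise brute-force k-NN (hypothetical helper), excluding each point itself."""
    n = a.shape[0]
    all_distances, all_indices = [], []
    for start in range(0, n, chunk_size):
        block = a[start:start + chunk_size]        # [b, D] query rows
        dist = torch.cdist(block, a)               # [b, N] slice of the distance matrix
        rows = torch.arange(block.shape[0], device=a.device)
        dist[rows, start + rows] = float("inf")    # mask each query's distance to itself
        d, i = torch.topk(dist, k, dim=1, largest=False)
        all_distances.append(d)
        all_indices.append(i)
    return torch.cat(all_distances), torch.cat(all_indices)

a = torch.rand(10, 2)
distances, indices = knn_chunked(a, k=3, chunk_size=4)
print(distances.shape, indices.shape)
```

Peak memory for the distance slice scales with `chunk_size * N` rather than `N * N`, at the cost of a Python loop over blocks.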

Hi amitoz, I think `torch_cluster` has a function you can call directly to compute the k-NN graph of a given torch tensor.

```python
from torch_cluster import knn_graph

# Returns the graph as an edge index of shape [2, num_edges]
graph = knn_graph(a, k, loop=False)
```

Set `loop=True` if you wish to include each node itself (a self-loop) in the graph.
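
Since the question also asks for distances, here is a minimal sketch of recovering them from a graph in edge-index form. It assumes the graph is a `[2, num_edges]` integer tensor whose columns are pairs of node indices. The helper `edge_distances` is a hypothetical name, and the small `edge_index` below is written by hand so the example runs without `torch_cluster`.

```python
import torch

def edge_distances(a, edge_index):
    """Euclidean length of every edge (hypothetical helper). Returns a [num_edges] tensor."""
    src, dst = edge_index  # two rows, each of shape [num_edges]
    return (a[src] - a[dst]).norm(dim=1)

a = torch.tensor([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]])
# Hand-written stand-in for an edge index such as the one knn_graph returns.
edge_index = torch.tensor([[1, 2],
                           [0, 0]])
print(edge_distances(a, edge_index))  # tensor([5., 1.])
```

Because the distance is computed per edge, this does not depend on the order in which the edges are listed.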