How to compute the k-NN graph of a tensor using PyTorch?

I have a tensor, say:

a = torch.rand(10, 2)

I would like to build a k-NN graph of this tensor a using torch, so that it returns k indices and k distances for each row of a. That is, distances.shape is [10, k] and indices.shape is [10, k].

Basically, I want to do something like this (section 1.6.1.1. Finding the Nearest Neighbors) in torch, but it is not obvious to me how.

One way I was able to come up with is:

dist = torch.sort(torch.cdist(a,a),dim=1)

Then, assuming k = 3:

distances = dist.values[:,0:k]
indices = dist.indices[:,0:k]
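For reference, here is a self-contained version of the sort-based approach above. One detail worth checking: every point is its own nearest neighbor at distance 0, so slicing [:, 0:k] includes the point itself. This sketch assumes you want k *other* points (and that there are no duplicate points), so it slices from column 1 instead; the seed is only there for reproducibility.

```python
import torch

torch.manual_seed(0)   # reproducible random points
a = torch.rand(10, 2)  # 10 points in 2-D
k = 3

# Sort each row of the full 10 x 10 pairwise Euclidean distance matrix.
dist = torch.sort(torch.cdist(a, a), dim=1)

# Column 0 is each point itself (distance 0), so skip it
# to keep the k nearest *other* points.
distances = dist.values[:, 1:k + 1]
indices = dist.indices[:, 1:k + 1]

print(distances.shape, indices.shape)  # torch.Size([10, 3]) torch.Size([10, 3])
```

If you do want each point included as its own neighbor, the original [:, 0:k] slice is what you need.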

This gives me what I want, but are there more efficient ways to achieve it?
PS: This method is very memory-inefficient, since it materializes the full N × N distance matrix. Is a k-d tree computation possible using torch?
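To my knowledge core PyTorch does not ship a k-d tree, so here is a hedged alternative that targets the memory concern instead: it is still brute force, but it processes query rows in chunks and uses torch.topk (which only keeps the k smallest values) rather than sorting whole rows. The function name knn_chunked and its chunk_size and loop parameters are my own, hypothetical choices; peak memory is roughly chunk_size × N distances instead of N × N, while total compute stays O(N²). With loop=False, k must be at most N - 1.

```python
import torch


def knn_chunked(x: torch.Tensor, k: int, chunk_size: int = 1024, loop: bool = False):
    """Brute-force k-NN of every row of x against all rows of x.

    Query rows are processed chunk_size at a time, so only a
    [chunk_size, N] block of distances is held in memory at once,
    instead of the full [N, N] matrix. Returns (distances, indices),
    each of shape [N, k], sorted by increasing distance.
    """
    n = x.shape[0]
    all_dists, all_idx = [], []
    for start in range(0, n, chunk_size):
        q = x[start:start + chunk_size]  # [c, d] query rows
        d = torch.cdist(q, x)            # [c, N] distances to every point
        if not loop:
            # Set each query's distance to itself to +inf so it is never picked.
            rows = torch.arange(q.shape[0], device=x.device)
            self_mask = torch.zeros_like(d, dtype=torch.bool)
            self_mask[rows, start + rows] = True
            d = d.masked_fill(self_mask, float("inf"))  # out-of-place, autograd-safe
        # largest=False returns the k smallest distances (sorted by default).
        vals, idx = torch.topk(d, k, dim=1, largest=False)
        all_dists.append(vals)
        all_idx.append(idx)
    return torch.cat(all_dists), torch.cat(all_idx)


a = torch.rand(10, 2)
distances, indices = knn_chunked(a, k=3, chunk_size=4)
print(distances.shape, indices.shape)  # torch.Size([10, 3]) torch.Size([10, 3])
```

The chunk size trades memory for the number of cdist/topk calls; on a GPU, larger chunks are usually faster as long as they fit.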

Hi amitoz, I think torch_cluster has a function you can call directly to compute the k-NN graph of a given tensor.

from torch_cluster import knn_graph

graph = knn_graph(a,k,loop=False)

Set loop=True if you wish to include each node itself in the graph (self-loops).
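Note that knn_graph returns an edge list (edge_index) of shape [2, N*k] rather than the [N, k] distances/indices asked about in the question. If installing torch_cluster is not an option, a similar edge list can be built in plain PyTorch. The helper knn_edge_index below is hypothetical, and the row convention (row 0 = neighbor/source, row 1 = center/target) is my assumption about knn_graph's default flow, so please check the torch_cluster docs for the flow argument before relying on it.

```python
import torch


def knn_edge_index(indices: torch.Tensor) -> torch.Tensor:
    """Turn an [N, k] neighbor-index matrix into a [2, N*k] edge list.

    Row 0 holds the neighbor (source) and row 1 the query node (target),
    i.e. one edge j -> i for every neighbor j of node i.
    """
    n, k = indices.shape
    centers = torch.arange(n, device=indices.device).repeat_interleave(k)  # [N*k]
    neighbors = indices.reshape(-1)                                         # [N*k]
    return torch.stack([neighbors, centers], dim=0)


torch.manual_seed(0)
a = torch.rand(10, 2)
k = 3
d = torch.cdist(a, a)
d.fill_diagonal_(float("inf"))  # exclude each node itself, like loop=False
_, indices = torch.topk(d, k, dim=1, largest=False)
edge_index = knn_edge_index(indices)
print(edge_index.shape)  # torch.Size([2, 30])
```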