How much can we speed up this KNN operation with the C++ API on CPU?

I have the following KNN operation. It takes a 3D tensor of shape (B, N, D) as input, where B is the batch size, N is the number of points, and D is the dimension of each point. It computes the pairwise Euclidean distances between the N points of each batch element (using the D coordinates of each point) and returns a 4D tensor of shape (B, N, k, D) containing the k nearest neighbors of every point. The entire operation should be differentiable so that it works with autograd.
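Written out, the operation is roughly the following (the symbols $d$, $\mathcal{N}_k$ and $y$ are introduced here only for illustration; $x_{b,i} \in \mathbb{R}^D$ is point $i$ of batch element $b$):

```latex
d_{b,i,j} = \lVert x_{b,i} - x_{b,j} \rVert_2, \qquad b \in \{1,\dots,B\},\; i,j \in \{1,\dots,N\}

\mathcal{N}_k(b,i) = \{ j_1, \dots, j_k \} = \text{indices of the } k \text{ smallest values of } d_{b,i,\cdot}

y_{b,i,m} = x_{b,\,j_m}, \qquad j_m \in \mathcal{N}_k(b,i),\; m \in \{1,\dots,k\}, \qquad y \in \mathbb{R}^{B \times N \times k \times D}
```

Since $d_{b,i,i} = 0$, each point is normally among its own neighbors. The index set $\mathcal{N}_k$ is piecewise constant in $x$, so autograd propagates gradients to $x$ only through the gathered values $y$, not through the neighbor selection.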

Here is the code:

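```python
import torch


def knn(x, k):
    B, N, D = x.shape

    # Pairwise Euclidean distances between the points of each batch element: [B, N, N]
    cdist = torch.cdist(x, x)
    # Indices of the k smallest distances for every point: [B, N, k]
    _, indices = cdist.topk(k, dim=-1, largest=False, sorted=False)

    # Offset by N * b so the indices address rows of the flattened [B * N, D] tensor
    # (the [B, 1, 1] offset broadcasts against [B, N, k]).
    indices = indices + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    x = x.reshape(B * N, D)[indices]  # shape: [B, N, k, D]
    return x
```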
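A quick sanity check, sketched under the assumption of small random inputs (the sizes B=2, N=16, D=3, k=4 are arbitrary). It repeats the function so it runs on its own, confirms the output shape, shows that backward works, and checks that the gathered points really are the k nearest ones.

```python
import torch


def knn(x, k):
    # Same logic as above, repeated so this snippet runs on its own.
    B, N, D = x.shape
    cdist = torch.cdist(x, x)
    _, indices = cdist.topk(k, dim=-1, largest=False, sorted=False)
    indices = indices + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    return x.reshape(B * N, D)[indices]


torch.manual_seed(0)
x = torch.randn(2, 16, 3, requires_grad=True)
k = 4

out = knn(x, k)
print(out.shape)  # torch.Size([2, 16, 4, 3])

# Backward works: out.sum() has gradient 1 everywhere, so each point receives
# one unit of gradient per coordinate each time it was selected as a neighbor.
out.sum().backward()
print(x.grad.sum().item())  # 2 * 16 * 4 * 3 = 384.0

# Distances to the gathered neighbors should equal the k smallest entries of
# each row of the distance matrix (both sorted ascending).
with torch.no_grad():
    got = (out - x.unsqueeze(2)).norm(dim=-1).sort(dim=-1).values
    expected = torch.cdist(x, x).topk(k, dim=-1, largest=False).values
print(torch.allclose(got, expected, atol=1e-5))  # True
```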

How efficient is this implementation, and how much speedup could I expect if I re-implemented this KNN operation at a lower level using the C++ API?

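I don’t think you will see a significant speedup if you use libtorch, but please update this thread if you decide to try it out.
Your code mostly calls PyTorch operations (cdist, topk, indexing), each of which runs as a native kernel, so Python adds only a small fixed cost per call. A libtorch version would call the same kernels, so it could only save that fixed overhead.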