I have the following KNN operation. It takes a 3D tensor of shape (B, N, D) as input, where B is the batch size, N is the number of points, and D is the dimension of each point. It computes the pairwise distances between the N points of each batch (using their D-dimensional coordinates) and returns a 4D tensor of shape (B, N, k, D) containing the k nearest neighbors of each point. The entire operation should be differentiable so that it works with autograd.
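Written out as formulas, the operation is roughly the following. The symbols $d$, $j_m$ and $y$ are notation introduced here for illustration and do not appear in the code.

$$
d_{b,i,j} = \lVert x_{b,i,:} - x_{b,j,:} \rVert_2, \qquad b = 1,\dots,B, \quad i, j = 1,\dots,N
$$

$$
\bigl(j_1(b,i), \dots, j_k(b,i)\bigr) = \text{indices of the } k \text{ smallest values among } d_{b,i,1}, \dots, d_{b,i,N}
$$

$$
y_{b,i,m,:} = x_{b,\,j_m(b,i),:}, \qquad m = 1,\dots,k, \qquad y \in \mathbb{R}^{B \times N \times k \times D}
$$

Because the indices $j_m(b,i)$ are piecewise constant in $x$, autograd only differentiates the last step: $\partial y_{b,i,m,:} / \partial x_{b,j_m(b,i),:}$ is the identity, and no gradient flows through $d$. Each point is at distance zero from itself, so it usually appears among its own neighbors. Computing all $d_{b,i,j}$ costs $O(BN^2D)$ work and a $B \times N \times N$ intermediate, which is typically what dominates for large $N$.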
A code snippet is attached as follows:
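```python
import torch

def knn(x, k):
    B, N, D = x.shape
    # Pairwise Euclidean distances between the N points of each batch: [B, N, N]
    cdist = torch.cdist(x, x)
    # Indices of the k closest points to every point: [B, N, k]
    _, indices = cdist.topk(k, dim=-1, largest=False, sorted=False)
    # Offset the indices so they address rows of the flattened [B * N, D] tensor
    indices = indices + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    x = x.reshape(B * N, D)[indices]  # shape: [B, N, k, D]
    return x
```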
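A small usage example can confirm the output shape and that gradients reach the input. This sketch repeats the function (with the device fix) so it runs on its own; the sizes B=2, N=16, D=3, k=4 are arbitrary choices. Since `out.sum()` sends a gradient of 1 to every gathered entry, each point's gradient equals the number of neighbor lists it appears in, so the gradients must sum to B·N·k·D = 384.

```python
import torch


def knn(x, k):
    # Same logic as the snippet above; returns the k nearest neighbors of every point, [B, N, k, D].
    B, N, D = x.shape
    dist = torch.cdist(x, x)
    _, idx = dist.topk(k, dim=-1, largest=False, sorted=False)
    idx = idx + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    return x.reshape(B * N, D)[idx]


torch.manual_seed(0)
x = torch.randn(2, 16, 3, dtype=torch.double, requires_grad=True)
out = knn(x, k=4)
print(out.shape)  # torch.Size([2, 16, 4, 3])

# Backward flows through the indexing step only; topk's integer indices carry no gradient.
out.sum().backward()
# Every gathered entry contributes 1, so the total is B * N * k * D = 2 * 16 * 4 * 3.
print(x.grad.sum().item())  # 384.0
```

Note that the selected indices themselves are not differentiable: moving a point changes which neighbors are chosen only in discrete jumps, and autograd does not see those jumps.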
I am wondering how efficient this implementation is. How much speedup can I expect if I re-implement this KNN operation at a lower level with the C++ API?
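Before porting anything, it may help to measure where the time actually goes. `torch.cdist`, `topk` and advanced indexing already dispatch to compiled ATen kernels, and the C++ frontend (`torch::cdist`, `Tensor::topk`, ...) calls those same kernels. A line-by-line C++ port would therefore mainly remove Python dispatch overhead, which is likely to be small next to the O(BN²D) distance computation when N is large. A larger gain would probably require a different algorithm, such as a fused kernel that never materializes the [B, N, N] distance matrix. The sketch below times each stage with `torch.utils.benchmark`. The sizes are assumptions to replace with your real workload, and `knn_nograd_idx` and `median_ms` are hypothetical helpers: the former builds the indices under `torch.no_grad()`, since only the integer indices are used.

```python
import torch
from torch.utils import benchmark


def knn(x, k):
    # Same logic as the snippet in the question.
    B, N, D = x.shape
    dist = torch.cdist(x, x)
    _, idx = dist.topk(k, dim=-1, largest=False, sorted=False)
    idx = idx + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    return x.reshape(B * N, D)[idx]


def knn_nograd_idx(x, k):
    # Hypothetical variant: build the integer indices without recording cdist/topk for autograd.
    B, N, D = x.shape
    with torch.no_grad():
        dist = torch.cdist(x, x)
        _, idx = dist.topk(k, dim=-1, largest=False, sorted=False)
        idx = idx + N * torch.arange(B, device=x.device).view(-1, 1, 1)
    return x.reshape(B * N, D)[idx]  # gradients still flow to x through this gather


def median_ms(stmt, env):
    # blocked_autorange chooses the repetition count; the median time is reported in milliseconds.
    timer = benchmark.Timer(stmt=stmt, globals=env)
    return timer.blocked_autorange(min_run_time=0.5).median * 1e3


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    B, N, D, k = 8, 2048, 3, 16  # assumed sizes; replace with your real workload
    x = torch.randn(B, N, D, device=device, requires_grad=True)
    dist = torch.cdist(x.detach(), x.detach())
    env = {"torch": torch, "knn": knn, "knn_nograd_idx": knn_nograd_idx,
           "x": x, "dist": dist, "k": k}
    stages = {
        "cdist": "torch.cdist(x, x)",
        "topk": "dist.topk(k, dim=-1, largest=False, sorted=False)",
        "knn forward": "knn(x, k)",
        "knn forward+backward": "knn(x, k).sum().backward()",
        "knn_nograd_idx forward": "knn_nograd_idx(x, k)",
    }
    for name, stmt in stages.items():
        print(f"{name:>24}: {median_ms(stmt, env):.3f} ms")
```

If the cdist and topk stages account for most of the end-to-end time, the Python layer is not the bottleneck, and a C++ rewrite of the same three calls is unlikely to change much.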