What is the fastest way to gather points with a kNN index tensor?

I use the following code to get the neighbor points, but it is slow because I use a Python for loop over the batch to obtain the desired neighbors. Is there a more efficient way? Should I use torch.gather or torch.take, and how would I index with them?

import torch

pts = torch.randn(2, 1024, 3)  # [batch, point_number, dim]
qrs = torch.randn(2, 512, 3)   # [batch, queries_number, dim]
D = torch.cdist(qrs, pts)      # [batch, queries_number, point_number]
K = 64  # number of neighbors
dist, idx = D.topk(k=K, dim=-1, largest=False)  # idx: [batch, queries_number, K]
nn_pts = torch.stack([pts[n][i, :] for n, i in enumerate(torch.unbind(idx, dim=0))], dim=0)  # [batch, queries_number, K, dim]

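You can index pts directly with advanced indexing. A batch index of shape [batch, 1, 1] broadcasts against idx of shape [batch, queries_number, K], so no loop is needed: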

pts=torch.randn(2,1024,3) # [batch, point_number, dim]
qrs=torch.randn(2,512,3)
D=torch.cdist(qrs, pts)
K=64 #number of neighbors
dist, idx = D.topk(k=K, dim=-1, largest=False)
nn_pts = torch.stack([pts[n][i,:] for n, i in enumerate(torch.unbind(idx, dim = 0))], dim = 0) # [batch, queries_number, K, dim]


ret = pts[torch.arange(pts.size(0)).unsqueeze(1).unsqueeze(2), idx]
print((ret == nn_pts).all())
> tensor(True)
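Since the question also asks about torch.gather and torch.take, here is a minimal sketch of two other loop-free variants, assuming pts is [B, N, D] and idx is an int64 tensor of shape [B, M, K] as produced by topk above. The helper names knn_gather and knn_flat are hypothetical. torch.gather needs an index with the same number of dimensions as its input, so the index is expanded across the feature dimension. The second variant adds a per-batch offset and indexes a flattened [B*N, D] view with index_select, which is the same idea torch.take uses (treating the input as flat). Which one is fastest will depend on device and tensor sizes, so it is worth timing them on your own data.

```python
import torch


def knn_gather(pts: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather kNN points with torch.gather.

    pts: [B, N, D], idx: [B, M, K] (int64) -> [B, M, K, D]
    """
    B, N, D = pts.shape
    _, M, K = idx.shape
    # Merge the query and neighbor dims: [B, M*K]
    flat_idx = idx.reshape(B, M * K)
    # gather requires index.ndim == input.ndim, so repeat each index over D
    flat_idx = flat_idx.unsqueeze(-1).expand(B, M * K, D)
    # out[b, j, d] = pts[b, flat_idx[b, j, d], d]
    out = torch.gather(pts, 1, flat_idx)  # [B, M*K, D]
    return out.reshape(B, M, K, D)


def knn_flat(pts: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: offset indexing into a flattened [B*N, D] view."""
    B, N, D = pts.shape
    # Shift batch b's indices by b * N so they address rows of the flat view
    offsets = torch.arange(B, device=idx.device).view(B, 1, 1) * N
    flat = (idx + offsets).reshape(-1)
    rows = pts.reshape(B * N, D).index_select(0, flat)  # [B*M*K, D]
    return rows.reshape(*idx.shape, D)


# Usage, mirroring the setup from the question
torch.manual_seed(0)
pts = torch.randn(2, 1024, 3)
qrs = torch.randn(2, 512, 3)
dist_mat = torch.cdist(qrs, pts)
dist, idx = dist_mat.topk(k=64, dim=-1, largest=False)

# Reference result from the advanced-indexing answer
ref = pts[torch.arange(pts.size(0)).unsqueeze(1).unsqueeze(2), idx]
a = knn_gather(pts, idx)
b = knn_flat(pts, idx)
print(a.shape)  # torch.Size([2, 512, 64, 3])
print(torch.equal(a, ref), torch.equal(b, ref))  # True True
```

All three approaches copy the same values, so the results should match exactly; the choice between them is mostly about readability and measured speed.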