pts = torch.randn(2, 1024, 3)  # shape: (batch, point_number, dim)
You can index `pts` directly with advanced indexing, without a Python loop
over the batch:
import torch

pts = torch.randn(2, 1024, 3)  # [batch, point_number, dim]
qrs = torch.randn(2, 512, 3)  # [batch, queries_number, dim]
# Pairwise Euclidean distances: [batch, queries_number, point_number].
D = torch.cdist(qrs, pts)
K = 64  # number of neighbors
# K smallest distances per query and their point indices: [batch, queries_number, K].
dist, idx = D.topk(k=K, dim=-1, largest=False)

# Reference result: gather neighbours one batch element at a time (Python loop).
nn_pts = torch.stack(
    [pts[n][i, :] for n, i in enumerate(torch.unbind(idx, dim=0))], dim=0
)  # [batch, queries_number, K, dim]

# Vectorized gather: a [batch, 1, 1] batch index broadcasts against idx
# ([batch, queries_number, K]), so pts[batch_idx, idx] selects, for every
# (b, q, k), the point pts[b, idx[b, q, k]] -> [batch, queries_number, K, dim].
# Built on pts.device so the index matches the indexed tensor's device.
batch_idx = torch.arange(pts.size(0), device=pts.device)[:, None, None]
ret = pts[batch_idx, idx]
print((ret == nn_pts).all())
Output: `tensor(True)`
1 Like