Hi pytorch,
I was wondering if anyone has come across code that finds the nearest neighbor of each of n points from a list of m points without replacement (n << m). So if points n1 and n2 are both closest to point m1, the closer one is selected and the farther one must recalculate its next nearest neighbor. Can anyone help?
Assuming Euclidean (L2) distance, with X of shape (n, d) and P of shape (m, d):
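```python
import torch

# X: (n, d) query points, P: (m, d) candidate points, with n <= m
n, m = X.shape[0], P.shape[0]

# (n, m) matrix of Euclidean distances; broadcasting does the same job as expand()
pairwise = (X.unsqueeze(1) - P.unsqueeze(0)).norm(p=2, dim=2)

# Sort all n*m (x, p) pairs from closest to farthest
_, rank = pairwise.view(-1).sort()
# Floor division (plain `/` returns floats in recent PyTorch). .tolist() gives
# plain ints: tensors hash by identity, so set lookups on 0-d tensors never match
rank_x = torch.div(rank, m, rounding_mode="floor").tolist()  # row index into X
rank_p = (rank % m).tolist()  # column index into P

# Walk the pairs in order, accepting one only if both x and p are still free
used = set()
match = [None] * n
rest = n
for x, p in zip(rank_x, rank_p):
    if p not in used and match[x] is None:
        match[x] = p
        used.add(p)
        rest -= 1
        if rest == 0:
            break
match = torch.tensor(match, dtype=torch.long)
```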
Sorry for the ugly second half; I can't think of a better solution.
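If the full sort of all n*m pairs is the bottleneck, one possible speed-up is to keep only each point's n nearest candidates before sorting. This is a sketch that assumes n <= m and no tied distances. With n points competing, at most n - 1 candidates can be taken before a given point is matched, so its greedy match always lies among its n nearest. That shrinks the sort from n*m pairs to n*n. The function name `greedy_match_topk` is hypothetical, and `torch.cdist` stands in for the explicit norm above.

```python
import torch

def greedy_match_topk(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Greedy nearest-neighbour matching without replacement (assumes n <= m).

    Same rule as the sorted loop, but each query only keeps its n nearest
    candidates: at most n - 1 of them can be taken by other queries first,
    so the greedy match of every query is always among them.
    """
    n, m = X.shape[0], P.shape[0]
    assert n <= m, "need at least as many candidates as queries"
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances
    k = n
    top_d, top_idx = dist.topk(k, dim=1, largest=False)  # (n, k) each
    order = top_d.reshape(-1).argsort()  # sort n*k pairs instead of n*m
    rows = torch.div(order, k, rounding_mode="floor").tolist()  # query index
    cols = top_idx.reshape(-1)[order].tolist()  # candidate index into P
    used, match, rest = set(), [-1] * n, n
    for x, p in zip(rows, cols):
        if match[x] == -1 and p not in used:
            match[x] = p
            used.add(p)
            rest -= 1
            if rest == 0:
                break
    return torch.tensor(match, dtype=torch.long)
```

For example, `greedy_match_topk(torch.randn(5, 3), torch.randn(50, 3))` returns a LongTensor of 5 distinct indices into P. The Python loop is still there, but it now visits at most n*n pairs.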
Thanks for the help. Unfortunately this is a lot slower than what I have right now, and I’m really not sure why.
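Part of the slowdown may come from the Python loop, which runs once per visited pair. One alternative moves the work onto tensor ops by matching in rounds. This is offered as a sketch, not a tested drop-in, and it assumes n <= m; `greedy_match_mutual` is a hypothetical name. In each round, commit every pair where p is x's nearest free candidate and x is p's nearest free query. No closer pair can claim either endpoint of such a mutual pair, so with no tied distances it gives the same matching as the sorted loop. Each round costs one O(n·m) pass. The number of rounds is at most n and depends on how much the points compete for the same candidates.

```python
import torch

def greedy_match_mutual(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Greedy matching without replacement, done in vectorized rounds.

    Each round commits every free pair (x, p) where p is x's nearest free
    candidate and x is p's nearest free query. Assumes n <= m.
    """
    n, m = X.shape[0], P.shape[0]
    assert n <= m, "need at least as many candidates as queries"
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances, modified in place below
    match = torch.full((n,), -1, dtype=torch.long)  # -1 marks an unmatched query
    rows = torch.arange(n)
    while (match < 0).any():
        nearest_p = dist.argmin(dim=1)  # nearest free candidate for each query
        nearest_x = dist.argmin(dim=0)  # nearest free query for each candidate
        # Mutual nearest neighbours among the still-free queries
        mutual = (match < 0) & (nearest_x[nearest_p] == rows)
        cols = nearest_p[mutual]
        match[mutual] = cols
        # Matched queries and candidates drop out by getting infinite distance
        dist[mutual] = float("inf")
        dist[:, cols] = float("inf")
    return match
```

Every round commits at least the closest remaining pair, so the loop terminates. `argmin` returns the first index on ties, which also keeps it from stalling when distances tie, although the tie-breaking may then differ from the sort-based loop.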