Nearest Neighbor without replacement

Hi pytorch,
I was wondering if anyone has come across code that finds the nearest neighbor for each of n points from a list of m points without replacement (n << m). So if points n1 and n2 are both closest to point m1, the closer one gets m1 and the farther one must fall back to its next nearest neighbor. Can anyone help?

Assuming Euclidean (L2) distance, with X of shape (n, d) and P of shape (m, d):

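import torch

n, m = X.size(0), P.size(0)
# (n, m) matrix of Euclidean distances; broadcasting replaces the explicit expand
pairwise = (X.unsqueeze(1) - P.unsqueeze(0)).norm(p=2, dim=2)
# flat indices of all (x, p) pairs, sorted from closest to farthest
_, rank = pairwise.view(-1).sort()
# integer division (plain / would give floats); tolist() yields Python ints,
# which hash by value, so the set membership test below actually works
rank_x = (rank // m).tolist()
rank_p = (rank % m).tolist()
used = set()  # indices into P that are already taken
nn = [None for _ in range(n)]
rest = n
for x, p in zip(rank_x, rank_p):
  if p not in used and nn[x] is None:
    nn[x] = p
    used.add(p)
    rest -= 1
    if rest == 0:
      break
nn = torch.tensor(nn)  # index into P chosen for each row of X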

Sorry for the ugly second half :frowning: I can’t think of a better solution.
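For anyone who wants to try the idea end to end, here is a minimal self-contained sketch of the same greedy matching. It is only one way to do it, under a few assumptions: the function name `greedy_unique_nearest` is made up for illustration, `torch.cdist` is swapped in for the manual broadcast-and-norm, and ties in distance are resolved in whatever order the sort returns them.

```python
import torch


def greedy_unique_nearest(X, P):
    """Greedily match each row of X to a distinct row of P (hypothetical helper)."""
    n, m = X.size(0), P.size(0)
    assert n <= m, "need at least as many candidates as query points"
    pairwise = torch.cdist(X, P)             # (n, m) Euclidean distances
    order = pairwise.reshape(-1).argsort()   # flat pair indices, closest pair first
    rows = (order // m).tolist()             # index into X for each pair
    cols = (order % m).tolist()              # index into P for each pair
    match = [-1] * n                         # -1 marks "not assigned yet"
    used = set()                             # indices into P already taken
    remaining = n
    for x, p in zip(rows, cols):
        if match[x] == -1 and p not in used:
            match[x] = p
            used.add(p)
            remaining -= 1
            if remaining == 0:
                break
    return torch.tensor(match)


# Both query points are closest to P[0]; the closer one (X[1]) keeps it
# and X[0] falls back to its next nearest candidate, P[1].
X = torch.tensor([[0.0, 0.0], [0.1, 0.0]])
P = torch.tensor([[0.2, 0.0], [1.0, 0.0], [-2.0, 0.0]])
print(greedy_unique_nearest(X, P))  # tensor([1, 0])
```

Note that this sorts all n * m pairs and then walks them in a Python loop, so it is meant to pin down the expected behavior rather than to be fast.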


Thanks for the help. Unfortunately this is a lot slower than what I have right now, and I’m really not sure why.