Hi PyTorch,

I was wondering if anyone has come across code that finds the nearest neighbor of each of n points from a list of m points, without replacement (n << m). That is, if points n1 and n2 are both closest to point m1, the closer of the two gets m1, and the farther one must fall back to its next nearest neighbor among the remaining points. Can anyone help?
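For reference, here is a minimal sketch of the round-based procedure described above, assuming Euclidean distance (computed with `torch.cdist`), float tensors `X: (n, d)` and `P: (m, d)`, and n <= m. The function name `match_by_rounds` is hypothetical. Each round, every still-unmatched point in X picks its nearest free point in P; when several pick the same one, the closest claimant keeps it and the others retry in the next round.

```python
import torch


def match_by_rounds(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Return assign[i] = index into P matched to X[i]; no index of P is used twice.

    Hypothetical helper. Each round, every unmatched query picks its nearest
    still-free candidate; on a conflict the closest query wins and the rest retry.
    """
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances
    n, m = dist.shape
    assert n <= m, "need at least as many candidates as queries"
    assign = torch.full((n,), -1, dtype=torch.long, device=dist.device)
    free = torch.ones(m, dtype=torch.bool, device=dist.device)

    while (assign < 0).any():
        todo = (assign < 0).nonzero(as_tuple=True)[0]  # indices of unmatched queries
        d = dist[todo]                                 # advanced indexing returns a copy
        d[:, ~free] = float("inf")                     # hide already-taken candidates
        best_d, choice = d.min(dim=1)                  # nearest free candidate per query
        for p in choice.unique().tolist():
            claimants = (choice == p).nonzero(as_tuple=True)[0]
            winner = claimants[best_d[claimants].argmin()]  # closest claimant keeps p
            assign[todo[winner]] = p
            free[p] = False
    return assign
```

Every round matches at least one query, so the loop runs at most n times. Note that this round-based rule may not give the same matching as always taking the globally closest remaining pair: the two can differ when a query loses its first choice and then competes for a candidate that was already handed out in an earlier round.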

Assuming Euclidean (L2) distance, with `X: (n, d)` and `P: (m, d)`:

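```
import torch

# X: (n, d) query points, P: (m, d) candidate points
n, d = X.shape
m = P.shape[0]
# (n, m) matrix of Euclidean distances; broadcasting replaces the explicit expand
pairwise = (X.unsqueeze(1) - P.unsqueeze(0)).norm(p=2, dim=2)
# Sort all n*m distances; flat index k corresponds to the pair (k // m, k % m)
_, rank = pairwise.view(-1).sort()
rank_x = (rank // m).tolist()  # integer division (rank / m gives floats in recent PyTorch)
rank_p = (rank % m).tolist()   # plain ints, so the set lookup below works
used = set()
nn = [None] * n
rest = n
# Greedily accept the closest pair whose x and p are both still unmatched
for x, p in zip(rank_x, rank_p):
    if p not in used and nn[x] is None:
        nn[x] = p
        used.add(p)
        rest -= 1
        if rest == 0:
            break
nn = torch.tensor(nn)  # LongTensor of indices into P
```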

Sorry for the ugly second half; I can't think of a better solution.


Thanks for the help. Unfortunately this is a lot slower than what I have right now, and I’m really not sure why.
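One likely reason, as far as I can tell: the second half of the answer is a Python loop that may walk through up to n·m sorted pairs one at a time, on top of a full sort of all n·m distances. Also, if `p` is a 0-d tensor rather than an int, `p not in used` is always true because tensors hash by identity, so the without-replacement check silently does nothing. Below is a sketch that keeps the same greedy rule (always take the globally closest remaining pair) but runs exactly n vectorized steps: take the argmin of the distance matrix, then set that row and column to infinity. `greedy_match` is a hypothetical name; it assumes `torch.cdist` for the distances and n <= m.

```python
import torch


def greedy_match(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: repeatedly take the globally closest remaining (x, p) pair.

    Same rule as the sort-based loop, but only n iterations, each one a
    vectorized argmin over the (n, m) distance matrix.
    """
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances
    n, m = dist.shape
    assert n <= m, "need at least as many candidates as queries"
    assign = torch.empty(n, dtype=torch.long, device=dist.device)
    for _ in range(n):
        flat = torch.argmin(dist)          # flat index of the smallest remaining distance
        x, p = divmod(flat.item(), m)      # row in X, column in P
        assign[x] = p
        dist[x, :] = float("inf")          # x is now matched
        dist[:, p] = float("inf")          # p can no longer be used
    return assign
```

Each step is an O(nm) argmin, so the total work is O(n²m); whether that beats the sort in practice depends on n and the device (it is not benchmarked here). On a GPU, `.item()` forces one synchronization per step.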