I was wondering if anyone has run across code that finds the nearest neighbor of each of n points from a list of m points without replacement (n << m). So if points n1 and n2 are both closest to point m1, the closer one is selected and the farther one must recalculate its next nearest neighbor. Can anyone help?
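A minimal sketch of the procedure as described, assuming PyTorch tensors `X` of shape (n, d) and `P` of shape (m, d) with n <= m, Euclidean distance, and a hypothetical function name `greedy_nn_rounds`. Each round, every unassigned point in X proposes its nearest unused point in P. When several points propose the same point, the closest proposer keeps it and the rest recompute in the next round. At least one point is assigned per round, so the loop ends after at most n rounds. Following the description literally like this does not always produce the same assignment as sorting all pairs globally and assigning greedily.

```python
import torch


def greedy_nn_rounds(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Give each row of X a distinct row of P (hypothetical helper, assumes n <= m)."""
    n, m = X.shape[0], P.shape[0]
    assert n <= m, "need at least as many candidates as queries"
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances
    assignment = torch.full((n,), -1, dtype=torch.long)
    used = torch.zeros(m, dtype=torch.bool)
    while (assignment < 0).any():
        free = (assignment < 0).nonzero(as_tuple=True)[0]  # unassigned X indices
        # Hide P points that are already taken from every unassigned X point
        d = dist[free].masked_fill(used, float("inf"))
        best_d, best_p = d.min(dim=1)  # each free X point proposes its nearest unused P
        # Visit proposals from closest to farthest; the stable sort breaks exact ties
        order = torch.sort(best_d, stable=True).indices
        cand_p = best_p[order]
        pos = torch.arange(len(order))
        # For each P point, the earliest (closest) position that proposed it
        first = torch.full((m,), len(order), dtype=torch.long)
        first = first.scatter_reduce(0, cand_p, pos, reduce="amin")
        won = first[cand_p] == pos  # closest proposer wins; the others retry next round
        assignment[free[order[won]]] = cand_p[won]
        used[cand_p[won]] = True
    return assignment
```

For example, with `X = [[0.0], [0.1]]` and `P = [[0.0], [1.0], [5.0]]`, both X points first propose `P[0]`. The closer one (`X[0]`) keeps it, and `X[1]` falls back to `P[1]` in the second round.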
Assuming Euclidean (L2) distance, with
X: (n, d),
P: (m, d):
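```python
import torch

n, d = X.shape
m = P.shape[0]

# (n, m) Euclidean distances; broadcasting makes the explicit expand() unnecessary
pairwise = (X.unsqueeze(1) - P.unsqueeze(0)).norm(p=2, dim=2)

# Sort all n*m distances once, smallest first
_, rank = pairwise.view(-1).sort()
# Convert flat indices back to (row, column). Use floor division: `/` on
# integer tensors returns floats in recent PyTorch versions.
# .tolist() yields plain ints, which hash by value in the `used` set
# (tensors hash by identity, so `p not in used` would never match).
rank_x = torch.div(rank, m, rounding_mode="floor").tolist()
rank_p = (rank % m).tolist()

used = set()
nn = [None for _ in range(n)]
rest = n
for x, p in zip(rank_x, rank_p):
    # Take the closest remaining pair whose X point is unassigned and whose P point is unused
    if p not in used and nn[x] is None:
        nn[x] = p
        used.add(p)
        rest -= 1
        if rest == 0:
            break
nn = torch.tensor(nn, dtype=torch.long)
```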
Sorry for the ugly second half; I can’t think of a better solution.
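The (n, m) distance matrix in the first half can also be computed with `torch.cdist`. For Euclidean distance on larger inputs, it may by default use a matrix-multiplication formulation instead of building the (n, m, d) difference tensor, which trades a little precision for memory. The sample sizes here are arbitrary. Float64 inputs keep that rounding difference negligible for the comparison.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 2, dtype=torch.float64)    # n = 4 query points, d = 2
P = torch.randn(100, 2, dtype=torch.float64)  # m = 100 candidate points

# Broadcasting form from the answer: materializes an (n, m, d) difference tensor
pairwise_bc = (X.unsqueeze(1) - P.unsqueeze(0)).norm(p=2, dim=2)
# Library form: the same (n, m) matrix, up to floating-point error
pairwise_cd = torch.cdist(X, P, p=2)

print(pairwise_cd.shape)
print(torch.allclose(pairwise_bc, pairwise_cd, atol=1e-6))
```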
Thanks for the help. Unfortunately this is a lot slower than what I have right now, and I’m really not sure why.
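One likely cost in the answer is that it sorts all n*m distances and then walks them in a Python loop. Because n << m, a smaller candidate set suffices. By the time the greedy loop reaches any given X point, at most n - 1 of its closer P points can already be taken by other X points, so its assignment is always among its n nearest P points. The sketch below (the name `greedy_nn_topk` is hypothetical) keeps only those candidates via `topk`, so the sort and the loop touch n*n pairs instead of n*m. Up to ties in distance, it should give the same assignment as the full sort. It has not been benchmarked against the approach used in this thread.

```python
import torch


def greedy_nn_topk(X: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Greedy assignment without replacement using only each row's n nearest candidates."""
    n, m = X.shape[0], P.shape[0]
    assert n <= m, "need at least as many candidates as queries"
    k = n  # an X point never needs more than n candidates
    dist = torch.cdist(X, P)  # (n, m) Euclidean distances
    cand_d, cand_p = dist.topk(k, dim=1, largest=False)  # (n, k) nearest per row
    # Sort the n*k candidate pairs globally, closest first
    order = cand_d.flatten().argsort()
    rows = torch.div(order, k, rounding_mode="floor").tolist()  # X index of each pair
    cols = cand_p.flatten()[order].tolist()  # P index of each pair
    used = set()
    assignment = [-1] * n
    remaining = n
    for x, p in zip(rows, cols):
        if assignment[x] < 0 and p not in used:
            assignment[x] = p
            used.add(p)
            remaining -= 1
            if remaining == 0:
                break
    return torch.tensor(assignment, dtype=torch.long)
```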