# K nearest neighbor in PyTorch

Hi, I have a tensor of size `[12936x4098]` and, after computing a similarity using `F.cosine_similarity`, I get a tensor of size `12936`. For a given point, how can I get its k nearest neighbors?
Using the clustering methods in `sklearn` or `scipy` is very slow and requires copying the tensor from the GPU to the CPU.

Thank you


After calculating the distance between your test sample and the data samples, you could probably use `topk` to get the nearest neighbors.
Would this work for you:

```python
import torch

data = torch.randn(100, 10)
test = torch.randn(1, 10)

# Euclidean distance from the test sample to every data sample
dist = torch.norm(data - test, dim=1, p=None)
# The 3 smallest distances and their indices
knn = dist.topk(3, largest=False)

print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))
```
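The question computes `F.cosine_similarity` rather than a distance. A minimal sketch of the same `topk` approach for similarities (random stand-in data, shapes assumed) keeps the largest values instead of the smallest. Everything stays on whatever device `data` lives on, so no GPU-to-CPU copy is needed.

```python
import torch
import torch.nn.functional as F

data = torch.randn(100, 10)
test = torch.randn(1, 10)

# cosine_similarity broadcasts the (1, 10) query against each of the 100 rows -> shape (100,)
sim = F.cosine_similarity(data, test, dim=1)

# Higher similarity means closer, so keep the largest values here
knn = sim.topk(3, largest=True)
print('kNN sim: {}, index: {}'.format(knn.values, knn.indices))
```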

Thank you, `topk` does the job. But I need the top-k for each point in `data`, and `topk` may end up with some overlap, i.e. a point belonging to more than one group. For example, I have a `12936x4096` tensor and would like to get pairs of indices that I can merge/fuse (if k=2). The distance between the two points in a pair should be minimal. I'm currently doing something like this:

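```python
import torch

x = torch.randn(12936, 4096)
y = x
m = len(x)
# ||x_i||^2 + ||y_j||^2 for every (i, j) pair, shape (m, m)
dists = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, m) + \
    torch.pow(y, 2).sum(dim=1, keepdim=True).expand(m, m).t()
# Subtract 2 * x_i . y_j to obtain the squared Euclidean distances
dists.addmm_(x, y.t(), beta=1, alpha=-2)
pairs = []
```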

How do I get the pairs of indices that can be merged?
Thank you.
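A minimal sketch of one possible reading of "merge pairs", assuming it means pairing each point with its single nearest other point: complete the squared-distance expansion with the cross term, mask the diagonal so a point cannot pick itself, and take the per-row minimum. The mutual-nearest-neighbor filter at the end is an extra assumption that addresses the overlap concern (each point ends up in at most one pair). The data is a smaller random stand-in.

```python
import torch

torch.manual_seed(0)
x = torch.randn(500, 64)  # smaller stand-in for the 12936 x 4096 tensor
y = x

# Squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y^T, shape (m, m)
sq_x = x.pow(2).sum(dim=1, keepdim=True)      # (m, 1)
sq_y = y.pow(2).sum(dim=1, keepdim=True).t()  # (1, m)
dists = sq_x + sq_y - 2 * x @ y.t()
dists = dists.clamp_min(0)  # remove tiny negative values caused by rounding

# Exclude self-matches on the diagonal, then take each row's closest point
dists.fill_diagonal_(float('inf'))
nn_dist, nn_idx = dists.min(dim=1)

# pairs[i] = (i, index of i's nearest neighbor)
pairs = torch.stack([torch.arange(len(x)), nn_idx], dim=1)
print(pairs[:5])

# Optional (assumption): keep only mutual nearest neighbors, i.e. i's NN is j
# and j's NN is i, listed once with i < j, so no point appears in two pairs
mutual = nn_idx[nn_idx] == torch.arange(len(x))
mutual_pairs = pairs[mutual & (pairs[:, 0] < pairs[:, 1])]
```

Mutual pairs do not cover every point; the remaining points would need a second pass (for example, recomputing on the unpaired subset).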

If I understand the code correctly, you are calculating the squared distance between all points, and using this distance you would like to get the (k=2) nearest neighbors? I'm not sure I understand the merge/fusion part. Could you give me a simple example using some random data points?
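For reference, the squared-distance expansion behind that snippet is the standard identity below: the two `expand(m, m)` terms supply the squared norms, and the cross term still has to be subtracted (for example with a matrix product of `x` and `y.t()`) to turn their sum into a squared distance.

$$
\|x_i - y_j\|_2^2 = \|x_i\|_2^2 + \|y_j\|_2^2 - 2\, x_i^\top y_j
$$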

Hi, I compute the pairwise distance between the points. Let's say I have

```python
import torch

x = torch.randn(128, 2048)
y = x
m = len(x)
dists = torch.cdist(x, y).pow(2)  # squared pairwise distances, torch.Size([128, 128])
```

I want to find the closest neighbor of a given point. I managed to do it using NumPy:

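```python
import numpy as np

dists = dists.numpy()
# Sort all entries of the flattened matrix and map each back to its (row, col) position.
# Since y = x, the diagonal (each point vs. itself) is ~0, so self-pairs come first unless masked.
ind = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
idx1 = ind[0]
idx2 = ind[1]
```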

Thus, the closest point to `x[idx1[0]]` is `x[idx2[0]]`. How can I do this in PyTorch directly?
Thank you.

This code should work:

```python
import numpy as np
import torch

dist = torch.randn(12, 12)

# numpy
dist_np = dist.numpy()
ind_np = np.unravel_index(np.argsort(dist_np, axis=None), dist_np.shape)
idx1_np = ind_np[0]
idx2_np = ind_np[1]

# PyTorch
ind = torch.sort(dist.flatten()).indices
idx1 = torch.div(ind, dist.size(1), rounding_mode='floor')  # row index (integer division)
idx2 = ind % dist.size(1)                                   # column index

print((idx1_np == idx1.numpy()).all())
print((idx2_np == idx2.numpy()).all())
```

Note that you might not get exactly the same results if you have exactly duplicated dist values.
As far as I know, the tie-breaking behavior of NumPy and PyTorch does not match for duplicated values.
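If matching results with ties matter, one option is to request a stable sort on both sides, so equal values keep their original flattened order. A minimal sketch, assuming PyTorch 1.9 or newer for the `stable` argument of `torch.sort`; integer-valued distances are used only to force many exact ties:

```python
import numpy as np
import torch

# Integer-valued distances guarantee many exact ties
dist = torch.randint(0, 3, (12, 12)).float()

# NumPy: stable argsort over the flattened matrix, then recover (row, col)
ind_np = np.argsort(dist.numpy(), axis=None, kind='stable')
idx1_np, idx2_np = np.unravel_index(ind_np, dist.shape)

# PyTorch: stable sort over the flattened matrix, then recover (row, col)
ind = torch.sort(dist.flatten(), stable=True).indices
idx1 = torch.div(ind, dist.size(1), rounding_mode='floor')
idx2 = ind % dist.size(1)

print((idx1_np == idx1.numpy()).all(), (idx2_np == idx2.numpy()).all())
```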


Great, thank you, it's working.

You can also use `torch.repeat_interleave` like this:

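```python
import torch

def knn(ref, query, k):
    # ref: (D, R) reference points, query: (D, Q) query points; each column is a point
    # Tile ref once per query and lay it out as (D, Q*R)
    ref_c = torch.stack([ref] * query.shape[-1], dim=0).permute(0, 2, 1).reshape(-1, ref.shape[0]).transpose(0, 1)
    # Repeat each query column R times so it lines up with every reference point: (D, Q*R)
    query_c = torch.repeat_interleave(query, repeats=ref.shape[-1], dim=1)
    delta = query_c - ref_c
    # Euclidean distance for every (query, ref) pair, reshaped to (Q, R)
    distances = torch.sqrt(torch.pow(delta, 2).sum(dim=0))
    distances = distances.view(query.shape[-1], ref.shape[-1])
    # Sort each row and keep the k closest reference points
    sorted_dist, indices = torch.sort(distances, dim=-1)
    return sorted_dist[:, :k], indices[:, :k]
```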
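For comparison, a minimal sketch of the same computation with `torch.cdist` and `topk`, keeping the layout of the `knn` function above (each column is a point). The name `knn_cdist` is hypothetical; it avoids building the tiled `(D, Q*R)` tensors and only sorts the k entries it returns.

```python
import torch

def knn_cdist(ref, query, k):
    """k nearest neighbors of each query column among the ref columns.

    ref: (D, R) tensor, query: (D, Q) tensor; points are stored as columns.
    Returns (distances, indices), each of shape (Q, k), sorted by distance.
    """
    # cdist expects points as rows, so transpose to (Q, D) and (R, D) -> (Q, R)
    distances = torch.cdist(query.t(), ref.t())
    # topk with largest=False returns the k smallest distances, sorted ascending
    return distances.topk(k, dim=1, largest=False)

ref = torch.randn(2, 50)
query = torch.randn(2, 8)
dist, idx = knn_cdist(ref, query, k=3)
print(dist.shape, idx.shape)  # torch.Size([8, 3]) torch.Size([8, 3])
```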

Thank you so much! Exactly what I needed

I implemented NN, KNN and KMeans using only PyTorch for a project I am working on. You can find the implementation here, with an example: Nearest Neighbor, K Nearest Neighbor and K Means (NN, KNN, KMeans) only using PyTorch · GitHub

```python
>>> import torch as th
>>> from clustering import KNN
>>> data = th.Tensor([[1, 1], [0.88, 0.90], [-1, -1], [-1, -0.88]])
>>> labels = th.LongTensor([3, 3, 5, 5])
>>> test = th.Tensor([[-0.5, -0.5], [0.88, 0.88]])
>>> knn = KNN(data, labels)
>>> knn(test)
tensor([5, 3])
```
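The linked implementation is not reproduced here. As an independent, minimal sketch of what a kNN classifier does on the same toy data, assuming a majority vote over the labels of the k nearest training points (the `knn_classify` name and `k=3` default are hypothetical):

```python
import torch

def knn_classify(data, labels, test, k=3):
    # Pairwise Euclidean distances between test points and training points: (T, N)
    dist = torch.cdist(test, data)
    # Indices of the k closest training points for each test point: (T, k)
    idx = dist.topk(k, dim=1, largest=False).indices
    # Majority vote over the neighbors' labels (torch.mode returns the most frequent value)
    return labels[idx].mode(dim=1).values

data = torch.tensor([[1, 1], [0.88, 0.90], [-1, -1], [-1, -0.88]])
labels = torch.tensor([3, 3, 5, 5])
test = torch.tensor([[-0.5, -0.5], [0.88, 0.88]])
print(knn_classify(data, labels, test))  # tensor([5, 3])
```

An odd `k` avoids two-way ties in the vote for a two-class problem like this one.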

Could you possibly provide the same code snippet, but for batched test data, i.e. `torch.randn(5, 10)` for example? I'm having some trouble with that. Thanks!

I assume you are referring to this post.
If so, then this batched approach should work:

```python
import torch

data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)

# Loop over the test samples one at a time
for test in test_batch:
    test = test.unsqueeze(0)
    dist = torch.norm(data - test, dim=1, p=None)
    knn = dist.topk(3, largest=False)
    print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))

# Output:
# kNN dist: tensor([2.2806, 2.4122, 2.5661]), index: tensor([97, 26, 13])
# kNN dist: tensor([2.7386, 2.8757, 3.0352]), index: tensor([87, 92, 26])
# kNN dist: tensor([1.8009, 1.8642, 2.3949]), index: tensor([75, 15, 37])
# kNN dist: tensor([2.1696, 2.2909, 2.5022]), index: tensor([80, 76, 28])
# kNN dist: tensor([1.9649, 2.0842, 2.2254]), index: tensor([27, 62, 58])

# Batched version via broadcasting: dist2 has shape (100, 5), one column per test sample
dist2 = torch.norm(data.unsqueeze(1) - test_batch.unsqueeze(0), dim=2, p=None)
knn2 = dist2.topk(3, largest=False, dim=0)
print('kNN dist:\n{}\nindex:\n{}'.format(knn2.values, knn2.indices))

# Output:
# kNN dist:
# tensor([[2.2806, 2.7386, 1.8009, 2.1696, 1.9649],
#         [2.4122, 2.8757, 1.8642, 2.2909, 2.0842],
#         [2.5661, 3.0352, 2.3949, 2.5022, 2.2254]])
# index:
# tensor([[97, 87, 75, 80, 27],
#         [26, 92, 15, 76, 62],
#         [13, 26, 37, 28, 58]])
```
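A possible alternative for the batched case is `torch.cdist`, which computes the same Euclidean distances without materializing the `(100, 5, 10)` difference tensor that broadcasting creates. In this minimal sketch the result is laid out as (test, data), so `topk` runs along `dim=1` and each row belongs to one test sample (the transpose of `knn2` above):

```python
import torch

data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)

# (5, 100) distance matrix: one row per test sample
dist = torch.cdist(test_batch, data)

# The 3 smallest distances per test sample and their indices into data
knn = dist.topk(3, dim=1, largest=False)
print(knn.values.shape, knn.indices.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```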

I want something like this, but returning all the points that are within a defined radius. Does PyTorch have an algorithm for that as well, like a k-d tree?

Assuming the defined radius is relative to the test sample, you could most likely filter out the invalid samples before calling `topk`.
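A minimal sketch of that filtering, done by brute force on the full distance vector rather than with a k-d tree. The `radius` and `k` values are hypothetical; the first part returns every point inside the radius, the second caps the result at the k nearest of them.

```python
import torch

data = torch.randn(100, 10)
test = torch.randn(1, 10)
radius = 4.0  # hypothetical search radius
k = 5         # hypothetical cap on the number of neighbors

# Euclidean distance from the test sample to every data sample: (100,)
dist = torch.norm(data - test, dim=1)

# All points inside the radius, sorted from closest to farthest
in_radius = torch.nonzero(dist <= radius).squeeze(1)
in_radius = in_radius[dist[in_radius].argsort()]

# At most k nearest points that are also inside the radius:
# out-of-radius distances become inf, so topk ranks them last and they are dropped afterwards
masked = dist.masked_fill(dist > radius, float('inf'))
vals, idx = masked.topk(k, largest=False)
keep = torch.isfinite(vals)
vals, idx = vals[keep], idx[keep]
```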

@ptrblck If I have the tensor `p = [0.01, 0.1, 0.04, 0.5, 0.24]` and want to compute the top-3, and my differentiable top-k algorithm gives me `tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])`, namely the inclusion probabilities (the probability of each index being in the top-k), how can I “pick” the ordered values from the original tensor given these probabilities, e.g. `p_top3 = [0.5, 0.24, 0.1]`, in a differentiable way?
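If the goal is only to read off the values of `p` at the positions the probabilities rank highest, plain indexing is one minimal sketch, using the numbers from the question. Gradients reach `p` through the selected entries, but not the probabilities, because the indices returned by `topk` are not differentiable.

```python
import torch

p = torch.tensor([0.01, 0.1, 0.04, 0.5, 0.24], requires_grad=True)
probs = torch.tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])

# Rank positions by inclusion probability and keep the top 3 (highest first)
idx = probs.topk(3, dim=1).indices.squeeze(0)

# Indexing p is differentiable with respect to p; the index choice itself is not
p_top3 = p[idx]
print(p_top3.detach())  # tensor([0.5000, 0.2400, 0.1000])

p_top3.sum().backward()
print(p.grad)  # tensor([0., 1., 0., 1., 1.])
```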

I'm unsure what “differentiable way” would mean in this context.
Assume the probabilities `tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])` are created by a model and are thus differentiable.
Would it work if you just call `topk` on them, as seen here?

``````python
import torch
import torch.nn as nn

model = nn.Linear(1, 5)
x = torch.randn(1, 1)
out = model(x)

print(out)
# tensor([[ 0.2601, -0.2404, -0.5585,  0.1546,  0.0896]],

t = torch.topk(out, k=3)
print(t)
# torch.return_types.topk(
# indices=tensor([[0, 3, 4]]))

# The selected values are differentiable, so gradients flow back into the model
t.values.mean().backward()