K nearest neighbors in PyTorch

Hi, I have a tensor of size [12936x4096], and after computing a similarity with F.cosine_similarity, I get a tensor of size 12936. For a given point, how can I get its k nearest neighbors?
Using the clustering methods in sklearn or scipy is very slow and requires copying the tensor from the GPU to the CPU.

Thank you


After calculating the distance between your test sample and all data samples, you could probably use topk to get the nearest neighbors.
Would this work for you:

import torch

data = torch.randn(100, 10)
test = torch.randn(1, 10)

dist = torch.norm(data - test, dim=1, p=None)
knn = dist.topk(3, largest=False)

print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))
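The question computes a cosine similarity rather than a distance. In that case the nearest neighbors are the most similar samples, so topk would be called with largest=True instead. A minimal sketch, assuming the same random data and test shapes as above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
data = torch.randn(100, 10)
test = torch.randn(1, 10)

# Cosine similarity of the test sample against every row of data -> shape [100]
# (test is broadcast against data along dim 0)
sim = F.cosine_similarity(data, test, dim=1)

# Higher similarity means "closer", so keep the largest values
knn_sim = sim.topk(3, largest=True)
print('kNN sim: {}, index: {}'.format(knn_sim.values, knn_sim.indices))
```

Everything stays on the tensor's device, so no copy to the CPU is needed.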

Thank you, topk does the job. But I need the topk for every point in the data, and the results may overlap, i.e. a point can belong to more than one neighborhood. For example, I have a 12936x4096 tensor, and I would like to get pairs of indices that I can merge/fuse (if k=2), where the distance within each pair should be minimal. I'm currently doing something like this:

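import torch

x = torch.randn(12936, 4096)
y = x
m = len(x)
# ||x_i||^2 + ||y_j||^2 for every pair (i, j)
dists = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, m) + \
        torch.pow(y, 2).sum(dim=1, keepdim=True).expand(m, m).t()
# subtract the cross term 2 * <x_i, y_j> in place -> squared distance ||x_i - y_j||^2
dists.addmm_(x, y.t(), alpha=-2)
pairs = []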

How do I get the pairs of indices that can be merged?
Thank you.

If I understand the code correctly, you are calculating the squared distance between all points.
Using this distance, you would like to get the (k=2) nearest neighbors? I'm not sure I understand the merge/fusion part. Could you give me a simple example using some random data points?
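To get the k nearest neighbors of every point (rather than of a single test sample) without materializing the full 12936x12936 matrix at once, one option is to process the rows in chunks with torch.cdist and mask each point's distance to itself. A minimal sketch; knn_all and chunk_size are hypothetical names, and Euclidean distance is assumed:

```python
import torch

def knn_all(x, k, chunk_size=1024):
    """Return (values, indices) of the k nearest neighbors of every row of x,
    excluding the row itself. Rows are processed in chunks so only a
    [chunk_size, m] block of distances exists at any time."""
    m = x.size(0)
    all_vals, all_idx = [], []
    for start in range(0, m, chunk_size):
        end = min(start + chunk_size, m)
        # Euclidean distances between this chunk and all rows: [end - start, m]
        d = torch.cdist(x[start:end], x)
        # Mask self-matches (global row start + r vs. column start + r)
        rows = torch.arange(end - start, device=x.device)
        d[rows, rows + start] = float('inf')
        vals, idx = d.topk(k, dim=1, largest=False)
        all_vals.append(vals)
        all_idx.append(idx)
    return torch.cat(all_vals), torch.cat(all_idx)

torch.manual_seed(0)
x = torch.randn(1000, 64)
vals, idx = knn_all(x, k=2, chunk_size=256)
print(vals.shape, idx.shape)  # torch.Size([1000, 2]) torch.Size([1000, 2])
```

Row i of idx then holds the two closest other points to x[i]; the chunk size trades memory for the number of cdist calls.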

Hi, I compute the pairwise distances between the points. Let's say I have

import torch

x = torch.randn(128, 2048)
y = x
m = len(x)
dists = torch.cdist(x, y)  # pairwise distances, torch.Size([128, 128])

I want to find the closest neighbor of a given point. I managed to do it using NumPy:

import numpy as np

dists = dists.numpy()
ind = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
idx1 = ind[0]
idx2 = ind[1]

Thus, the closest point to x[idx1[0]] is x[idx2[0]]. How can I do this in PyTorch directly?
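Since dists here is the distance of x to itself, its diagonal holds the (near-)zero self-distances and the matrix is symmetric, so a flat argsort returns the self-pairs (i, i) first and every other pair twice. A sketch that keeps each unordered pair once, staying in PyTorch and assuming Euclidean distances from torch.cdist:

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 2048)
dists = torch.cdist(x, x)  # [128, 128] pairwise distances

# Keep each unordered pair (i, j) with i < j once; this drops the diagonal
iu = torch.triu_indices(dists.size(0), dists.size(1), offset=1)
pair_dists = dists[iu[0], iu[1]]

# Sort all pairs by distance, closest first
order = torch.argsort(pair_dists)
idx1, idx2 = iu[0][order], iu[1][order]
print(idx1[0].item(), idx2[0].item(), pair_dists[order[0]].item())
```

Here x[idx1[0]] and x[idx2[0]] are the two distinct points that are closest to each other.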
Thank you.

This code should work:

import numpy as np
import torch

dist = torch.randn(12, 12)

# numpy
dist_np = dist.numpy()
ind_np = np.unravel_index(np.argsort(dist_np, axis=None), dist_np.shape)
idx1_np = ind_np[0]
idx2_np = ind_np[1]

# PyTorch
ind = torch.sort(dist.flatten()).indices
idx1 = torch.div(ind, dist.size(1), rounding_mode='floor')  # row index (integer division)
idx2 = ind % dist.size(1)  # column index

print((idx1_np == idx1.numpy()).all())
print((idx2_np == idx2.numpy()).all())

Note that you might not get exactly the same results if dist contains duplicated values.
As far as I know, NumPy and PyTorch do not necessarily order duplicated values the same way.
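If only the few smallest entries are needed, a full sort is not necessary: topk on the flattened matrix returns them directly, and the same floor-division/modulo step maps the flat indices back to (row, col). A minimal sketch; k = 5 is an arbitrary illustrative choice:

```python
import torch

torch.manual_seed(0)
dist = torch.randn(12, 12)

# Only the k smallest entries are needed, so topk avoids sorting everything
k = 5
vals, flat_idx = dist.flatten().topk(k, largest=False)

# Convert flat indices back to (row, col); floor division keeps integer results
rows = torch.div(flat_idx, dist.size(1), rounding_mode='floor')
cols = flat_idx % dist.size(1)
print(torch.equal(dist[rows, cols], vals))  # True
```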


Great, thank you, it's working.

You can also use torch.repeat_interleave like this:

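import torch

def knn(ref, query, k):
    # ref: [D, N_ref], query: [D, N_query]; points are stored as columns
    # copy all reference points once per query point -> [D, N_query * N_ref]
    ref_c = torch.stack([ref] * query.shape[-1], dim=0).permute(0, 2, 1).reshape(-1, ref.shape[0]).transpose(0, 1)
    # repeat each query point N_ref times consecutively -> [D, N_query * N_ref]
    query_c = torch.repeat_interleave(query, repeats=ref.shape[-1], dim=1)
    delta = query_c - ref_c
    distances = torch.sqrt(torch.pow(delta, 2).sum(dim=0))
    distances = distances.view(query.shape[-1], ref.shape[-1])
    sorted_dist, indices = torch.sort(distances, dim=-1)
    return sorted_dist[:, :k], indices[:, :k]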
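Assuming the same layout as knn above (ref as [D, N_ref], query as [D, N_query]), torch.cdist computes the same [N_query, N_ref] distance matrix without building the repeated copies, and topk replaces the full sort. A sketch under that layout assumption; knn_cdist is a hypothetical name:

```python
import torch

def knn_cdist(ref, query, k):
    """k nearest neighbors via torch.cdist.
    ref: [D, N_ref], query: [D, N_query]; returns distances and indices of shape [N_query, k]."""
    # cdist expects points in rows, so transpose both inputs to [N, D]
    distances = torch.cdist(query.t(), ref.t())  # [N_query, N_ref]
    return distances.topk(k, dim=-1, largest=False)

torch.manual_seed(0)
ref = torch.randn(2, 50)
query = torch.randn(2, 7)
dist, idx = knn_cdist(ref, query, k=3)
print(dist.shape, idx.shape)  # torch.Size([7, 3]) torch.Size([7, 3])
```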

Thank you so much! Exactly what I needed

Could you provide the same snippet for batched test data, e.g. torch.randn(5, 10)? I'm having some trouble with that. Thanks!

I assume you are referring to this post.
If so, then this batched approach should work:

import torch

data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)

for test in test_batch:
    test = test.unsqueeze(0)
    dist = torch.norm(data - test, dim=1, p=None)
    knn = dist.topk(3, largest=False)
    print('kNN dist: {}, index: {}'.format(knn.values, knn.indices))

> kNN dist: tensor([2.2806, 2.4122, 2.5661]), index: tensor([97, 26, 13])
  kNN dist: tensor([2.7386, 2.8757, 3.0352]), index: tensor([87, 92, 26])
  kNN dist: tensor([1.8009, 1.8642, 2.3949]), index: tensor([75, 15, 37])
  kNN dist: tensor([2.1696, 2.2909, 2.5022]), index: tensor([80, 76, 28])
  kNN dist: tensor([1.9649, 2.0842, 2.2254]), index: tensor([27, 62, 58])

dist2 = torch.norm(data.unsqueeze(1) - test_batch.unsqueeze(0), dim=2, p=None)
knn2 = dist2.topk(3, largest=False, dim=0)
print('kNN dist:\n{}\nindex:\n{}'.format(knn2.values, knn2.indices))
> kNN dist:
  tensor([[2.2806, 2.7386, 1.8009, 2.1696, 1.9649],
          [2.4122, 2.8757, 1.8642, 2.2909, 2.0842],
          [2.5661, 3.0352, 2.3949, 2.5022, 2.2254]])
  index:
  tensor([[97, 87, 75, 80, 27],
          [26, 92, 15, 76, 62],
          [13, 26, 37, 28, 58]])
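If you prefer one row of neighbors per test sample (instead of the [3, 5] layout above), torch.cdist gives the [5, 100] distance matrix directly and topk can run along dim=1. A minimal sketch using the same shapes:

```python
import torch

torch.manual_seed(0)
data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)

# [5, 100]: one row of distances per test sample
dist = torch.cdist(test_batch, data)
knn = dist.topk(3, dim=1, largest=False)
print(knn.values.shape, knn.indices.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```

Row i of knn.values should match (up to floating-point rounding) the i-th result of the loop version.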

I want something like this, but returning all points within a given radius. Does PyTorch have an algorithm for that as well, like a k-d tree?

Assuming the defined radius is relative to the test sample, you could most likely filter out the invalid samples before calling topk.
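A minimal sketch of that filtering, assuming Euclidean distance and a hypothetical radius of 3.0: a boolean mask gives all samples inside the radius, and pushing out-of-radius distances to +inf before topk leaves only valid neighbors among the finite results.

```python
import torch

torch.manual_seed(0)
data = torch.randn(100, 10)
test_batch = torch.randn(5, 10)
radius = 3.0  # hypothetical radius, relative to each test sample

dist = torch.cdist(test_batch, data)   # [5, 100]
within = dist <= radius                 # True where a sample lies inside the radius

# The number of neighbors varies per test sample, so collect one index tensor per row
neighbors = [torch.nonzero(row, as_tuple=True)[0] for row in within]

# Or filter before topk by pushing out-of-radius samples to +inf
k = 3
masked = dist.masked_fill(~within, float('inf'))
knn = masked.topk(k, dim=1, largest=False)
valid = torch.isfinite(knn.values)      # False where fewer than k samples are in range
```

This is still a brute-force search over all samples, not a spatial index like a k-d tree.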

@ptrblck If I have the tensor p = [0.01, 0.1, 0.04, 0.5, 0.24] and want to compute the top-3, and my differentiable top-k algorithm gives me tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]]), i.e. the inclusion probabilities (the probability of each index being in the top-k), how can I “pick” the ordered values from the original tensor given these probabilities, e.g. p_top3 = [0.5, 0.24, 0.1], in a differentiable way?

I’m unsure what “differentiable way” would mean in this context.
Assuming the probabilities (e.g. tensor([[0.3008, 0.6230, 0.4807, 0.8423, 0.7532]])) are created by a model and are thus differentiable, would it work if you just call topk on them, as seen here?

import torch
import torch.nn as nn

model = nn.Linear(1, 5)
x = torch.randn(1, 1)
out = model(x)

print(out)
# tensor([[ 0.2601, -0.2404, -0.5585,  0.1546,  0.0896]],
#         grad_fn=<AddmmBackward0>)

t = torch.topk(out, k=3)
print(t)
# torch.return_types.topk(
# values=tensor([[0.2601, 0.1546, 0.0896]], grad_fn=<TopkBackward0>),
# indices=tensor([[0, 3, 4]]))

t.values.mean().backward()
print(model.weight.grad)
# tensor([[-0.2168],
#         [ 0.0000],
#         [ 0.0000],
#         [-0.2168],
#         [-0.2168]])

Yes, the top-k operation is differentiable, but I was asking about the case where I don't use torch.topk() and instead use another algorithm with smoother gradient behaviour that outputs inclusion probabilities, such as the one in the paper “Differentiable Top-k Operator with Optimal Transport”. How can I translate these inclusion probabilities (as in my example) into an ordered choice like the output of torch.topk()?
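One simple option (an assumption, not something the paper prescribes) is to use the inclusion probabilities only to choose and order the indices, then gather the values from p. Gradients then flow into p through the gathered values, but not through the discrete selection itself. A sketch using the numbers from the question:

```python
import torch

p = torch.tensor([0.01, 0.1, 0.04, 0.5, 0.24], requires_grad=True)
# Inclusion probabilities from the example in the question, treated as given
incl = torch.tensor([0.3008, 0.6230, 0.4807, 0.8423, 0.7532])

# Rank indices by inclusion probability and keep the top 3 -> indices 3, 4, 1
idx = incl.topk(3).indices
# Gather the corresponding values from p -> 0.5, 0.24, 0.1
p_top3 = p[idx]
# If the order should follow p's own values instead, sort the gathered values
p_top3_sorted, _ = p_top3.sort(descending=True)

p_top3.sum().backward()
print(p.grad)  # tensor([0., 1., 0., 1., 1.])
```

In this example both orderings agree because the probabilities happen to rank the indices the same way as p does.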

This is an updated link for the project (Nearest Neighbor, K Nearest Neighbor and K Means (NN, KNN, KMeans) implemented only using PyTorch · GitHub). It is under a new username.