Explanation of torch_cluster knn

Hello, I’m trying to compute a batched version of KNN. At the moment I’m looping over scipy’s cKDTree.

I saw that PyTorch Geometric has a GPU implementation of KNN. However, I find the documentation not very clear: the x and y input variables are matrices of points times features. Now, if I run the example code

import torch
from torch_cluster import knn

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)

the output of the example code is a 2x4 tensor. I was expecting a 2x2 tensor, i.e. for each point in y, the indices of its two nearest neighbors.
Can someone explain what this output means?

Also, if I want to run this with x and y being batched tensors (BxNxF), would I have to flatten the tensors and create a kind of one-hot-encoding batch_x that indicates which batch each point belongs to?

Thanks for your help

I’m not completely sure as I haven’t used this method before, but based on the output:

print(assign_index)
# tensor([[0, 0, 1, 1],
#         [0, 1, 2, 3]])

I would assume that the first row gives the indices into y, while the second row gives the indices of the nearest neighbors in x for that y index.

Let’s change the example a bit and see if this is true.

# make perfect matches
x = torch.Tensor([[-1, 0], [1, 0], [1, -1], [-1, 0]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]]) # first sample should match 0 and 3, second should match 1 and 2
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
print(assign_index)
# tensor([[0, 0, 1, 1],
#         [0, 3, 1, 2]])

Based on this I guess my assumption is correct.
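If that reading is correct and every point in y really gets k neighbors back, you could reshape the second row into the (num_y, k) matrix from the original question. A minimal sketch, assuming the pairs come back grouped by y index as in the outputs above:

import torch
from torch_cluster import knn

# same toy data as above
x = torch.Tensor([[-1, 0], [1, 0], [1, -1], [-1, 0]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
k = 2

assign_index = knn(x, y, k, batch_x, batch_y)  # shape [2, num_y * k]
y_idx, x_idx = assign_index  # row 0 indexes into y, row 1 into x

# assuming every y point has exactly k neighbors and the pairs are
# grouped by y index, a reshape gives one row of neighbors per y point
neighbors = x_idx.view(-1, k)
print(neighbors)
# tensor([[0, 3],
#         [1, 2]])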

Isn’t N already defining the batch dimension?

I don’t know how batch_x and batch_y are used and don’t fully understand the docs, but again, I’m not familiar enough with pytorch_geometric, so let’s wait for some experts.

Thanks for the explanation, your example makes sense.

Let me elaborate on the batch part of the question while we wait for the experts :slight_smile:
Imagine you have a cloud x of 10 points (N) in 2D (F) and some 5 query points y for which you want to find the nearest neighbors at time t. Now, the points in x can change over time, and at each timestep you need to find the NN of another set of y points.

So, you want to limit the search for each y_t to the corresponding x_t. If you appended all batches of x into a single NxF matrix, you would be searching across different timesteps.
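To make the BxNxF part of the original question concrete: as far as I understand, you would flatten the batch dimension into the point dimension and build an integer (not one-hot) batch vector that tells knn which timestep each point belongs to. A rough sketch with made-up shapes:

import torch
from torch_cluster import knn

B, N, M, F = 4, 10, 5, 2  # timesteps, points in x, points in y, features
x = torch.randn(B, N, F)
y = torch.randn(B, M, F)

# flatten the batch dimension into the point dimension
x_flat = x.reshape(B * N, F)
y_flat = y.reshape(B * M, F)

# integer batch vectors: point i belongs to batch batch_x[i]
batch_x = torch.arange(B).repeat_interleave(N)  # [0, ..., 0, 1, ..., 1, ...]
batch_y = torch.arange(B).repeat_interleave(M)

assign_index = knn(x_flat, y_flat, 3, batch_x, batch_y)
print(assign_index.shape)  # [2, B * M * 3] when every y point has 3 neighbors in its batch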


Ah OK, thanks for the explanation.
If I thus want to combine both examples into a single operation, could I just assign my new example to timestep 1?
This output seems to show that it’s working fine:

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1], [-1, 0], [1, 0], [1, -1], [-1, 0]])
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
y = torch.Tensor([[-1, 0], [1, 0], [-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0, 1, 1])
assign_index = knn(x, y, 2, batch_x, batch_y)
print(assign_index)
# tensor([[0, 0, 1, 1, 2, 2, 3, 3],
#         [0, 1, 2, 3, 4, 7, 5, 6]])
# equal to the previous answers with an offset, where the offsets are
# tensor([[0, 0, 0, 0, 2, 2, 2, 2],
#         [0, 0, 0, 0, 4, 4, 4, 4]])

The “second half” of the answer matches the result of my previous code snippet, shifted by the offsets shown above.
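And if you want the per-timestep (local) indices back rather than the global ones, one option could be to subtract those offsets again. A small sketch along those lines, assuming as above that every y point gets exactly k neighbors:

import torch
from torch_cluster import knn

x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1], [-1, 0], [1, 0], [1, -1], [-1, 0]])
batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
y = torch.Tensor([[-1, 0], [1, 0], [-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0, 1, 1])
k = 2

assign_index = knn(x, y, k, batch_x, batch_y)
y_idx, x_idx = assign_index

# number of x points per batch, turned into cumulative start offsets
x_counts = torch.bincount(batch_x)                 # tensor([4, 4])
x_offsets = torch.cumsum(x_counts, 0) - x_counts   # tensor([0, 4])

# each returned pair belongs to the batch of its y point
pair_batch = batch_y[y_idx]

# subtract that batch's offset to get local (per-timestep) x indices
x_idx_local = x_idx - x_offsets[pair_batch]
print(x_idx_local)
# tensor([0, 1, 2, 3, 0, 3, 1, 2])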