Hello, I’m trying to compute a batched k-nearest-neighbours (KNN) search. At the moment I’m looping over the batch and calling scipy’s `cKDTree` on each element.
I saw that PyTorch Geometric has a GPU implementation of KNN. However, I don’t find the documentation very clear. The `x` and `y` input variables are (points x features) matrices. Now, if I run the example code:
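```python
import torch
from torch_geometric.nn import knn

# Reference points (4 points, 2 features), all in batch element 0
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
batch_x = torch.tensor([0, 0, 0, 0])
# Query points (2 points, 2 features), also in batch element 0
y = torch.Tensor([[-1, 0], [1, 0]])
batch_y = torch.tensor([0, 0])
assign_index = knn(x, y, 2, batch_x, batch_y)
```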
the output is a 2x4 tensor. I was expecting a 2x2 one, i.e. the two nearest neighbours of each point in `y`.
Can someone explain what this output is?
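A minimal brute-force sketch in NumPy may help read a 2x4 result. It assumes the output follows the edge-index layout PyG uses for graph connectivity: row 0 holds indices into `y`, row 1 holds indices into `x`, with `k` consecutive columns per query point, so the shape would be 2 x (num_y * k) rather than num_y x k. The function `knn_reference` is hypothetical, and it breaks distance ties by lower `x` index, which the library may not do.

```python
import numpy as np

def knn_reference(x, y, k):
    """Hypothetical brute-force KNN returning a 2 x (len(y) * k) index array.

    Row 0: index of the query point in y.
    Row 1: index of one of its k nearest neighbours in x.
    """
    # Squared Euclidean distances, shape (len(y), len(x)).
    d2 = ((y[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    # k smallest distances per query; a stable sort breaks ties by lower x index.
    nbr = np.argsort(d2, axis=1, kind="stable")[:, :k]
    row = np.repeat(np.arange(len(y)), k)  # each y index repeated k times
    col = nbr.reshape(-1)                  # matching x indices
    return np.stack([row, col])

x = np.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=float)
y = np.array([[-1, 0], [1, 0]], dtype=float)
assign_index = knn_reference(x, y, 2)
print(assign_index)
# [[0 0 1 1]
#  [0 1 2 3]]
```

Under that assumed layout, the 4 columns are the 2 query points times k = 2 neighbours each, and `assign_index[1].reshape(-1, k)` gives the num_y x k table of neighbour indices (provided every query gets exactly k neighbours and the columns are grouped by query).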
Also, if I want to run this with `x` and `y` being batched tensors (B x N x F), would I have to flatten the tensors and build a kind of one-hot-encoding `batch_x` that indicates the batch?
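One possible setup, sketched in NumPy under the assumption that batch vectors follow the convention in the example above (one sorted integer batch id per row, rather than a one-hot matrix): flatten B x N x F to (B*N) x F and build the batch vector with `np.repeat`. The helpers `flatten_batched` and `batched_knn_reference` are hypothetical; the brute-force search only illustrates how batch ids keep neighbours inside the same batch element. With PyG, the flattened arrays (as tensors) and batch vectors would be what is passed to `knn(x, y, k, batch_x, batch_y)`, and the returned indices would refer to flattened rows.

```python
import numpy as np

def flatten_batched(t):
    """Hypothetical helper: (B, N, F) -> ((B*N, F) points, (B*N,) sorted batch ids)."""
    B, N, F = t.shape
    return t.reshape(B * N, F), np.repeat(np.arange(B), N)

def batched_knn_reference(x, y, k, batch_x, batch_y):
    """Hypothetical brute-force KNN that only matches points with equal batch ids.

    Returns a 2 x (len(y) * k) array: row 0 indexes y, row 1 indexes x.
    Assumes every batch element of x has at least k points.
    """
    d2 = ((y[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    # Points from a different batch element are pushed to infinite distance.
    d2[batch_y[:, None] != batch_x[None, :]] = np.inf
    nbr = np.argsort(d2, axis=1, kind="stable")[:, :k]
    return np.stack([np.repeat(np.arange(len(y)), k), nbr.reshape(-1)])

rng = np.random.default_rng(0)
B, N, M, F, k = 3, 5, 2, 4, 2
xb = rng.normal(size=(B, N, F))  # B x N x F reference points
yb = rng.normal(size=(B, M, F))  # B x M x F query points

x, batch_x = flatten_batched(xb)
y, batch_y = flatten_batched(yb)
assign_index = batched_knn_reference(x, y, k, batch_x, batch_y)

# Global (flattened) x indices -> local indices within each batch element.
local = assign_index[1] - batch_x[assign_index[1]] * N
print(local.reshape(B, M, k))  # B x M x k table of neighbour indices
```

Because all points of one batch element are contiguous after flattening, subtracting `batch * N` maps each global index back to a position inside its own B x N x F slice.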
Thanks for your help