Fastest way to find nearest neighbor for a set of points

Hi,

I’m trying to implement the Chamfer distance discussed in this paper: https://arxiv.org/abs/1612.00603. To do so, I need to do the following: given two unordered sets of the same size N, find the nearest neighbor of each point in the other set. The only way I can think of is to build an NxN matrix containing the pairwise distances between the points and then take the argmin. However, I’m not sure this approach fully takes advantage of how parallelizable nearest-neighbor search is. Any suggestions on how to approach this problem?

Thank you,
Lucas
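For reference, the NxN-matrix-plus-argmin approach from the question is already fully parallel on CPU and GPU when written with broadcasting. Here is a minimal sketch, assuming squared Euclidean distance (the function name `nearest_neighbors_bruteforce` is hypothetical). Note that `Tensor.min(dim=...)` returns both the minimum values and their indices, so a separate argmin is not needed.

```python
import torch


def nearest_neighbors_bruteforce(x: torch.Tensor, y: torch.Tensor):
    """For each point in x (N, D), find its nearest neighbor in y (M, D).

    Builds the full (N, M, D) difference tensor, so memory grows as N * M * D.
    """
    diff = x.unsqueeze(1) - y.unsqueeze(0)   # (N, M, D) via broadcasting
    dist = diff.pow(2).sum(dim=-1)           # (N, M) squared Euclidean distances
    min_dist, idx = dist.min(dim=1)          # nearest y for every x
    return min_dist, idx


x = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
y = torch.tensor([[1.0, 0.9], [0.1, 0.0]])
d, i = nearest_neighbors_bruteforce(x, y)
print(i)  # tensor([1, 0])
```

The intermediate (N, M, D) tensor is the main cost here; the replies in this thread avoid it by expanding the squared distance into norms and inner products.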

You can efficiently build a similarity matrix using a few tensor operations, which are parallelized on both CPU and GPU.
Check the thread "Build your own loss function in PyTorch" for an implementation.
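A minimal sketch of that kind of within-set computation, assuming squared Euclidean distance (the function name `self_pairwise_sq_dist` is hypothetical and not taken from the linked thread). It relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a·b, so only an (N, N) Gram matrix is formed:

```python
import torch


def self_pairwise_sq_dist(x: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between all rows of x (N, D) -> (N, N)."""
    gram = x @ x.t()              # (N, N) inner products x_i . x_j
    sq_norms = gram.diag()        # (N,) squared norms ||x_i||^2
    # Row term + column term - 2 * inner products, broadcast to (N, N)
    return sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * gram


pts = torch.tensor([[0.0, 0.0], [3.0, 4.0]])
print(self_pairwise_sq_dist(pts))  # tensor([[ 0., 25.], [25.,  0.]])
```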

Thank you, Francisco, for the fast reply. This looks like a valid solution.

Alternatively, you could look into using Faiss.

Hi again. After looking at the link you posted, it seems to calculate the pairwise distances between points within the same set. Is there any way to adapt it to calculate the distances between two distinct sets?

If you don’t have too many vectors, you can also consider using scipy’s cdist: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html
It is exact but still quite fast, although an approximate nearest-neighbour search is of course faster.
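A minimal sketch of exact nearest-neighbour lookup with `scipy.spatial.distance.cdist`, using the `"sqeuclidean"` metric to match the squared distances used elsewhere in this thread. The result is a plain NumPy array, so no gradients are tracked.

```python
import numpy as np
from scipy.spatial.distance import cdist

x = np.array([[0.0, 0.0], [1.0, 1.0]])
y = np.array([[1.0, 0.9], [0.1, 0.0]])

dist = cdist(x, y, metric="sqeuclidean")    # (N, M) squared distances
nn_idx = dist.argmin(axis=1)                 # index of the nearest y for each x
nn_dist = dist[np.arange(len(x)), nn_idx]    # the corresponding distances
print(nn_idx)  # [1 0]
```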

Hi,

Thanks for your answer. Unfortunately, I’m using the distance as a loss function, so my implementation needs to be in PyTorch so that I can backpropagate.
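For anyone finding this later: more recent PyTorch releases provide `torch.cdist`, which computes the same kind of pairwise distance matrix as scipy’s `cdist` but on tensors, so gradients flow through it. A minimal sketch, assuming squared Euclidean distance is wanted (`torch.cdist` with its default `p=2.0` returns plain Euclidean distance, hence the `.pow(2)`). The test points are chosen so that no distance is exactly zero.

```python
import torch

x = torch.tensor([[0.0, 0.0], [1.0, 1.0]], requires_grad=True)
y = torch.tensor([[1.0, 0.9], [0.1, 0.0]])

dist = torch.cdist(x, y).pow(2)      # (N, M) squared Euclidean distances
min_dist, nn_idx = dist.min(dim=1)   # nearest y for each x
loss = min_dist.mean()
loss.backward()                      # gradients reach x through min and cdist
print(nn_idx)  # tensor([1, 0])
```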

Yes, it’s possible to adapt it to two different sets. I’d recommend looking into implementations of MATLAB’s pdist2 and adapting one to PyTorch (there are a number of them available). (Typing from my phone, sorry for the lack of references.)
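pdist2 is the two-set counterpart of pdist. A minimal PyTorch sketch of the same idea for squared Euclidean distance (the name `pdist2_sq` is hypothetical), which also allows the two sets to have different sizes:

```python
import torch


def pdist2_sq(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between rows of x (N, D) and y (M, D) -> (N, M)."""
    x_sq = (x * x).sum(dim=1, keepdim=True)       # (N, 1) squared norms of x
    y_sq = (y * y).sum(dim=1, keepdim=True).t()   # (1, M) squared norms of y
    dist = x_sq + y_sq - 2 * (x @ y.t())          # broadcasts to (N, M)
    # Round-off can make near-zero entries slightly negative; clip them.
    return dist.clamp(min=0.0)


print(pdist2_sq(torch.tensor([[0.0, 0.0]]),
                torch.tensor([[3.0, 4.0], [1.0, 0.0]])))  # tensor([[25.,  1.]])
```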

The previous answer can be adapted to compute the distances between two sets, as in the thread "Maximum mean discrepancy (MMD) and radial basis function (rbf)", where P in that answer is the pairwise distances between all the elements of X and all the elements of Y, and K and L are the within-set distances (X with X, and Y with Y).
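Written out for squared Euclidean distances, assuming the rows of X (N x D) and Y (M x D) are the points, and introducing r_x = diag(XXᵀ) and r_y = diag(YYᵀ) as the vectors of squared norms, the expansion behind P, K and L is:

```latex
\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\, x_i^\top y_j
\quad\Longrightarrow\quad
P = r_x \mathbf{1}_M^\top + \mathbf{1}_N r_y^\top - 2\, X Y^\top \in \mathbb{R}^{N \times M},
\qquad
K = r_x \mathbf{1}_N^\top + \mathbf{1}_N r_x^\top - 2\, X X^\top,
\qquad
L = r_y \mathbf{1}_M^\top + \mathbf{1}_M r_y^\top - 2\, Y Y^\top
```

In the code posted in this thread, the P term corresponds to `rx.t() + ry - 2*zz`: the first two terms are the squared norms broadcast along rows and columns, and `zz` is XYᵀ.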

I am also trying to build a loss function based on the distance between two unordered sets, for a different application, and I’m also trying to figure out how best to do this in PyTorch. I’d be very interested to know if you found any good solutions! Thanks.

Hi, thanks for all the suggestions. Using the code posted, I was able to implement NN for two sets. Now that I’m trying to implement it in batch, I need to fetch the diagonal of each matrix in a 3D tensor. In other words, my input has shape bs x num_points x points_dim, so the Gram matrix bmm(x, x.transpose(2, 1)) has shape bs x num_points x num_points, and I would like to fetch its diagonal for every batch element to get a bs x num_points tensor. Is there an efficient way to do so?

Thanks
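Assuming the matrix in question is the batched Gram matrix `xx` of shape bs x num_points x num_points, here is a minimal sketch of three ways to get a bs x num_points diagonal. `torch.diagonal` with `dim1=1, dim2=2` returns a view without copying. The second option skips the Gram matrix entirely, because its diagonal is just the squared norm of each point, which avoids the num_points x num_points allocation.

```python
import torch

torch.manual_seed(0)
bs, num_points, points_dim = 2, 4, 3
x = torch.randn(bs, num_points, points_dim)
xx = torch.bmm(x, x.transpose(1, 2))          # (bs, num_points, num_points)

# Option 1: diagonal over dims 1 and 2 for every batch element (a view).
diag1 = torch.diagonal(xx, dim1=1, dim2=2)    # (bs, num_points)

# Option 2: the Gram diagonal equals the squared norms, no xx needed.
diag2 = (x * x).sum(dim=-1)                   # (bs, num_points)

# Option 3: advanced indexing, with the index created on the input's device.
idx = torch.arange(num_points, device=x.device)
diag3 = xx[:, idx, idx]                       # (bs, num_points)
```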

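import torch


def pairwise_dist(x, y):
    # x: (N, D), y: (M, D) -> (N, M) squared Euclidean distances
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = xx.diag().unsqueeze(1)  # (N, 1) squared norms of x
    ry = yy.diag().unsqueeze(0)  # (1, M) squared norms of y
    P = rx + ry - 2 * zz         # broadcasts to (N, M)
    return P


def NN_loss(x, y, dim=0):
    dist = pairwise_dist(x, y)
    # dim=0: for each point in y, the distance to its nearest point in x
    values, indices = dist.min(dim=dim)
    return values.mean()


def batch_pairwise_dist(a, b):
    # a: (bs, N, D), b: (bs, M, D) -> (bs, N, M) squared Euclidean distances
    x, y = a, b
    bs, num_points_x, points_dim = x.size()
    num_points_y = y.size(1)
    xx = torch.bmm(x, x.transpose(2, 1))
    yy = torch.bmm(y, y.transpose(2, 1))
    zz = torch.bmm(x, y.transpose(2, 1))
    # Build the diagonal indices on the inputs' device instead of hard-coding CUDA
    diag_ind_x = torch.arange(0, num_points_x, device=x.device)
    diag_ind_y = torch.arange(0, num_points_y, device=y.device)
    rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(2)  # (bs, N, 1)
    ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1)  # (bs, 1, M)
    P = rx + ry - 2 * zz                             # broadcasts to (bs, N, M)
    return P


def batch_NN_loss(x, y, dim=1):
    assert dim != 0  # dim 0 is the batch dimension
    dist = batch_pairwise_dist(x, y)
    values, indices = dist.min(dim=dim)
    return values.mean(dim=-1)  # one loss value per batch element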

Here is how I ended up doing it. I’m not sure I’m fetching the diagonal elements as fast as possible, but it seems to work.

How about this one:

import torch


def pairwise_dist(xyz1, xyz2):
    # xyz1: (B, N, D), xyz2: (B, M, D) -> (B, M, N) squared Euclidean distances
    r_xyz1 = torch.sum(xyz1 * xyz1, dim=2, keepdim=True)  # (B,N,1)
    r_xyz2 = torch.sum(xyz2 * xyz2, dim=2, keepdim=True)  # (B,M,1)
    mul = torch.matmul(xyz2, xyz1.permute(0, 2, 1))       # (B,M,N)
    dist = r_xyz2 - 2 * mul + r_xyz1.permute(0, 2, 1)     # (B,M,N)
    return dist

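I think this consumes less memory than your implementation, since it only computes the squared norm of each point instead of the full bs x N x N and bs x M x M Gram matrices xx and yy. Note that the result is laid out as (B, M, N), the transpose of the (bs, N, M) layout returned by batch_pairwise_dist.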
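Putting it together: the paper defines the Chamfer distance by summing, over each set, the squared distance from every point to its nearest neighbour in the other set. Below is a minimal sketch of a symmetric batched Chamfer loss built on the memory-lean `pairwise_dist` above (repeated so the block runs on its own). `chamfer_loss` is a hypothetical name, and it averages over points instead of summing, which only rescales the loss by a constant when set sizes are fixed.

```python
import torch


def pairwise_dist(xyz1, xyz2):
    # Same computation as the post above: squared distances, shape (B, M, N).
    r_xyz1 = torch.sum(xyz1 * xyz1, dim=2, keepdim=True)   # (B, N, 1)
    r_xyz2 = torch.sum(xyz2 * xyz2, dim=2, keepdim=True)   # (B, M, 1)
    mul = torch.matmul(xyz2, xyz1.permute(0, 2, 1))        # (B, M, N)
    return r_xyz2 - 2 * mul + r_xyz1.permute(0, 2, 1)      # (B, M, N)


def chamfer_loss(xyz1, xyz2):
    """Symmetric Chamfer distance per batch element: (B, N, D), (B, M, D) -> (B,)."""
    dist = pairwise_dist(xyz1, xyz2).clamp(min=0.0)  # clip round-off negatives
    d1, _ = dist.min(dim=1)   # (B, N): each xyz1 point to its nearest xyz2 point
    d2, _ = dist.min(dim=2)   # (B, M): each xyz2 point to its nearest xyz1 point
    return d1.mean(dim=1) + d2.mean(dim=1)


xyz1 = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]])
xyz2 = torch.tensor([[[1.0, 0.9], [0.1, 0.0]]])
loss = chamfer_loss(xyz1, xyz2)   # approximately tensor([0.0200])
```

Both directions are needed: with only one of them, a prediction could collapse onto a few target points without being penalized.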
