Efficient Distance Matrix Computation

Hi All,

For the project I’m working on right now I need to compute distance matrices over large batches of data. I have two matrices X and Y, where X is n x d and Y is m x d. The distance matrix D is then n x m and contains the squared Euclidean distance between each row of X and each row of Y.

So far I’ve implemented this in a few different ways, but each has its issues, and I’m hoping someone more experienced with PyTorch might be able to help me get an implementation that matches my needs. I’ll go through each approach and the related issues below.

Approach A - Direct Expanded

import torch

def expanded_pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    if y is not None:
        differences = x.unsqueeze(1) - y.unsqueeze(0)
    else:
        differences = x.unsqueeze(1) - x.unsqueeze(0)
    distances = torch.sum(differences * differences, -1)
    return distances

This approach adds extra dimensions so that the differences between every row of x and every row of y are computed at once. The intermediate tensor has shape N x M x d, so it requires a lot of memory and is slow.
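A rough way to see the memory cost: the broadcast difference in Approach A has shape N x M x d, and `differences * differences` creates a second tensor of the same size before the sum. The helper name `expanded_intermediate_bytes` and the sizes used below are made up for illustration.

```python
import torch

def expanded_intermediate_bytes(n, m, d, dtype=torch.float32):
    # Size of one n x m x d tensor, e.g. x.unsqueeze(1) - y.unsqueeze(0) in Approach A.
    # Approach A holds two such tensors at once (differences and differences * differences).
    element_size = torch.tensor([], dtype=dtype).element_size()
    return n * m * d * element_size

# Shape check on small inputs
x = torch.randn(4, 3)
y = torch.randn(5, 3)
print((x.unsqueeze(1) - y.unsqueeze(0)).shape)  # torch.Size([4, 5, 3])

# e.g. n = m = 10_000 rows with d = 128 features in float32: about 51.2 GB
print(expanded_intermediate_bytes(10_000, 10_000, 128) / 1e9, "GB per intermediate")
```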

Approach B - Quadratic Expansion

import torch

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist

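This approach requires less memory and is faster than the one above. However, it suffers from rounding errors, which in my setup lead to numerical instability: when x[i] and y[j] are close, the large terms ||x[i]||^2 + ||y[j]||^2 and 2 x[i]·y[j] nearly cancel, so the result can lose most of its precision or even come out slightly negative. This seems to be a well-documented issue with this approach, and existing libraries use thresholding to fall back to the direct computation when needed.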
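A minimal sketch of the thresholding idea, assuming a relative tolerance `rel_tol` (an illustrative value, not taken from any particular library) decides which entries are "suspiciously small". Those entries are recomputed with the direct difference, and everything else keeps the fast quadratic-expansion result. The function name `pairwise_distances_with_fallback` is hypothetical.

```python
import torch

def pairwise_distances_with_fallback(x, y, rel_tol=1e-3):
    # Hypothetical helper: quadratic expansion first, then direct recomputation
    # of the entries where cancellation is likely to have destroyed precision.
    x_norm = (x ** 2).sum(1, keepdim=True)          # (N, 1)
    y_norm = (y ** 2).sum(1, keepdim=True).t()      # (1, M)
    dist = x_norm + y_norm - 2.0 * (x @ y.t())      # (N, M)

    # Distances that are small relative to the norms are where a^2 + b^2 - 2ab
    # loses digits. rel_tol is an assumed, illustrative threshold.
    suspect = dist < rel_tol * (x_norm + y_norm)
    rows, cols = suspect.nonzero(as_tuple=True)
    if rows.numel() > 0:
        direct = ((x[rows] - y[cols]) ** 2).sum(1)  # exact-style (x_i - y_j)^2
        dist = dist.index_put((rows, cols), direct)  # out-of-place, autograd-friendly
    return dist.clamp_min(0.0)
```

The fallback is cheap when few pairs are close; in the worst case (many near-duplicate rows) it degrades toward the cost of the direct computation for those entries.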

Approach C - Direct row-wise

import torch

def row_pairwise_distances(x, y=None, dist_mat=None):
    if y is None:
        y = x
    if dist_mat is None:
        # N x M output with the same dtype and device as x
        dist_mat = x.new_empty(x.size(0), y.size(0))

    for i, row in enumerate(x.split(1)):
        # Compare one row of x (1 x d) against every row of y (M x d)
        r_v = row.expand_as(y)
        sq_dist = torch.sum((r_v - y) ** 2, 1)
        dist_mat[i] = sq_dist
    return dist_mat

The above is numerically stable and has a lower memory footprint than A, but it is far slower than Approach B. I think it should be possible to modify C to be faster, but I don’t know enough about the autograd engine and PyTorch data storage to make the tweaks myself. Does anybody have any suggestions?
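One possible middle ground between A and C, sketched here under the assumption that a block of `chunk_size` rows (a hypothetical tuning parameter) fits in memory: keep the direct, numerically stable subtraction, but process many rows of x per iteration instead of one. The function name `chunked_pairwise_distances` is made up for this example.

```python
import torch

def chunked_pairwise_distances(x, y=None, chunk_size=256):
    # Direct (x_i - y_j)^2 computation, like Approach C, but over blocks of rows:
    # peak extra memory is about chunk_size x M x d instead of N x M x d.
    if y is None:
        y = x
    chunks = []
    for x_chunk in x.split(chunk_size):
        diff = x_chunk.unsqueeze(1) - y.unsqueeze(0)   # (chunk, M, d)
        chunks.append((diff * diff).sum(-1))           # (chunk, M)
    return torch.cat(chunks, dim=0)                    # (N, M)
```

The number of Python-level iterations drops from N to ceil(N / chunk_size), so `chunk_size` trades memory against loop overhead.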


Hello. I am working on a project that needs to calculate this distance too. First, I want to thank you for sharing such a detailed write-up.
I first tried Approach A and ran out of memory; judging from the PyTorch documentation, there doesn't seem to be a good built-in solution to date.
I hadn't thought of calculating (a-b)^2 as a^2+b^2-2ab. I think this is the best way in terms of speed and memory when using the ready-made PyTorch API.
For numerical stability, the best option may be to write a cffi extension and use CUDA to calculate the distances directly.
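The identity behind Approach B, written out for a single entry and then for the whole matrix (the vectors r_X and r_Y are notation introduced here for the row norms). It also shows where the rounding problem in Approach B comes from.

```latex
\begin{aligned}
D_{ij} &= \lVert x_i - y_j \rVert^2 = \sum_{k=1}^{d} (x_{ik} - y_{jk})^2
        = \lVert x_i \rVert^2 + \lVert y_j \rVert^2 - 2\, x_i^\top y_j, \\
D &= r_X \mathbf{1}_m^\top + \mathbf{1}_n r_Y^\top - 2\, X Y^\top,
\qquad (r_X)_i = \lVert x_i \rVert^2,\quad (r_Y)_j = \lVert y_j \rVert^2 .
\end{aligned}
```

The three terms map onto x_norm, y_norm and 2.0 * torch.mm(x, y.t()) in the code. When x_i ≈ y_j, the first two terms add up to roughly 2||x_i||^2 and the cross term subtracts almost the same amount, so D_ij is a small difference of large floating-point numbers; that cancellation is where the rounding error (including small negative values) comes from.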

Glad there is something helpful here for you! In the end I managed to get a suitable solution by adapting Approach B to clamp the distance matrix to be positive and ensure the diagonal is zero when needed:

import numpy as np
import torch

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Ensure diagonal is zero if x=y
    # if y is None:
    #     dist = dist - torch.diag(dist.diag)
    return torch.clamp(dist, 0.0, np.inf)

For my code this was enough to keep things stable during optimization. I think that in general you might be correct that we’d need a cffi extension.
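For the y=None case, the same two fixes (clamping and zeroing the diagonal) can be written with `clamp_min` and an out-of-place `masked_fill`. This is a sketch, not the code from the post above; `self_pairwise_distances` is a hypothetical name, and the boolean identity mask is one way among several to zero the diagonal.

```python
import torch

def self_pairwise_distances(x):
    # Quadratic expansion for y = x
    x_norm = (x ** 2).sum(1, keepdim=True)           # (N, 1)
    dist = x_norm + x_norm.t() - 2.0 * (x @ x.t())    # (N, N)
    # Clamp tiny negative values produced by rounding
    dist = dist.clamp_min(0.0)
    # Force the diagonal to exactly zero; masked_fill is out-of-place,
    # so it does not modify a tensor autograd may still need
    eye = torch.eye(x.size(0), dtype=torch.bool, device=x.device)
    return dist.masked_fill(eye, 0.0)
```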


@AtheMathmo: For PyTorch version 0.2.0_4 (I’m not sure about the latest PyTorch version), I found that torch.clamp(dist, 0.0, np.inf) didn’t deal with the NaN values that sometimes arise along the diagonal when running pairwise_distances(x, y=None). To mitigate this issue, I replaced the line with the following:

dist[dist != dist] = 0 # replace nan values with 0
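The `dist != dist` trick works because NaN is the only value that is not equal to itself. A sketch of two more explicit spellings follows; note that `torch.nan_to_num` only exists in newer PyTorch releases (not 0.2.x), and by default it also replaces ±inf with the largest/smallest finite values of the dtype. The small `dist` tensor is made up for illustration.

```python
import torch

dist = torch.tensor([[float("nan"), 4.0], [4.0, float("nan")]])

# Same effect as dist[dist != dist] = 0, spelled with torch.isnan (in-place on a copy)
masked = dist.clone()
masked[torch.isnan(masked)] = 0.0

# Out-of-place alternative available in newer PyTorch releases
cleaned = torch.nan_to_num(dist, nan=0.0)
```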


I expand both x (n x d) and y (m x d) as follows:

n = x.size(0)
m = y.size(0)
d = x.size(1)

x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)

Then calculate dist[i,j] = ||x[i,:]-y[j,:]||^2 in the following way:

dist = torch.pow(x - y, 2).sum(2) 

Both x and y are torch.Tensor. Is this exactly what you want? Or is there a more efficient way?


Thank you so much! With your function, after moving the matrix to CUDA, I was able to compute pairwise distances about 210 to 220 times faster than with scipy.spatial.distance.pdist! I checked this on matrices of 1000 to 10000 samples of flattened CIFAR-10 images.
There is a small discrepancy with respect to the SciPy function: I estimated the relative error (max / median, taken over each pairwise distance) to be around 1e-5, which is fine for me.

I think your approach is almost identical to the author’s first approach. Redundant.


This answer is MEGA AESTHETIC.
thanks!!

also shout out to @AtheMathmo for posing the similar problem I had in terms that helped me understand my own situation


Hi, thank you for your Approach B (quadratic expansion). Pretty impressive!
But may I ask whether there is a similar way to calculate the L1 distance (||x - y||_1) efficiently?
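A sketch for the L1 case. As far as I know, there is no matrix-product identity for absolute differences analogous to the a^2+b^2-2ab expansion, so this version falls back to broadcasting, with the same N x M x d memory cost as Approach A. In newer PyTorch releases, `torch.cdist(x, y, p=1)` computes the same matrix. `pairwise_l1_distances` is a hypothetical name and the input sizes are arbitrary.

```python
import torch

def pairwise_l1_distances(x, y):
    # dist[i, j] = sum_k |x[i, k] - y[j, k]|, via an (N, M, d) broadcast
    return (x.unsqueeze(1) - y.unsqueeze(0)).abs().sum(-1)

x = torch.randn(4, 3)
y = torch.randn(5, 3)
d1 = pairwise_l1_distances(x, y)       # (4, 5)
d1_cdist = torch.cdist(x, y, p=1)      # same values, newer PyTorch releases
```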

Note that the methods above are not as numerically stable as torch.norm(input[:, None] - input, dim=2, p=2) or torch.pdist, which was recently added for faster computation of the distance matrix and returns only its upper-triangular values (as a flat vector).

Example:

import torch

a = torch.tensor([[ 1.4335, -1.0990, -0.8586],
        [ 2.1553,  2.7028, -0.8020],
        [ 1.0524,  0.1599, -0.0374]])

def dist(x):
    assert len(x.size()) == 2
    norm = (x ** 2).sum(1).view(-1, 1)
    dn = (norm + norm.view(1, -1)) - 2.0 * (x @ x.t())
    return dn.sqrt()

def dist2(x):
    return torch.norm(x[:, None] - x, dim=2, p=2)

then

dist(a)
tensor([[9.7656e-04, 3.8701e+00, 1.5507e+00],  # --> entry (0,0) MUST be zero
        [3.8701e+00, 0.0000e+00, 2.8752e+00],
        [1.5507e+00, 2.8752e+00, 0.0000e+00]])

dist2(a) 
tensor([[0.0000, 3.8701, 1.5507],
        [3.8701, 0.0000, 2.8752],
        [1.5507, 2.8752, 0.0000]])

torch.pdist(a) --> tensor([3.8701, 1.5507, 2.8752])
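To turn the condensed `torch.pdist` output back into a full matrix, here is a sketch assuming the entries come in row-major upper-triangular order (i < j), which matches the example output above: (0,1), (0,2), (1,2). `condensed_to_square` is a hypothetical helper, loosely mirroring what SciPy's `squareform` does for NumPy arrays.

```python
import torch

def condensed_to_square(condensed, n):
    # torch.triu_indices(n, n, offset=1) lists the (i, j) pairs with i < j
    # in row-major order, the same order as the pdist output above.
    rows, cols = torch.triu_indices(n, n, offset=1, device=condensed.device)
    full = condensed.new_zeros(n, n)
    full[rows, cols] = condensed   # upper triangle
    full[cols, rows] = condensed   # mirror into the lower triangle
    return full

a = torch.tensor([[ 1.4335, -1.0990, -0.8586],
                  [ 2.1553,  2.7028, -0.8020],
                  [ 1.0524,  0.1599, -0.0374]])
full = condensed_to_square(torch.pdist(a), a.size(0))
```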

Can we use the currently available torch.pairwise_distance, or does it effectively use your first approach?

No, I don’t think so. What we have here is the equivalent of SciPy’s cdist: the two sets of points do not have to be the same, and they do not need to contain an equal number of points. torch.pairwise_distance, on the other hand, only computes distances between corresponding rows, so it does require that.
It fails for something like the following (the dimensions are batch size, N, D):

x = torch.zeros((3, 5, 2))
y = torch.zeros((3, 2, 2))

x[0, 0, 0] = 1
torch.pairwise_distance(x, y, p=2)

However, you can vote for the feature request here: https://github.com/pytorch/pytorch/issues/15253
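To illustrate the difference: `torch.pairwise_distance` pairs row i of x with row i of y, while the all-pairs matrix in this thread pairs every row of x with every row of y. A small sketch with arbitrarily chosen shapes; note that `pairwise_distance` adds a small `eps` (1e-6 by default) to the difference, so the match with the diagonal of the all-pairs matrix is only approximate.

```python
import torch

x = torch.randn(5, 2)
y = torch.randn(5, 2)

# Row-matched: distance between x[i] and y[i] only -> shape (5,)
matched = torch.pairwise_distance(x, y, p=2)

# All pairs (what cdist computes): distance between x[i] and y[j] -> shape (5, 5)
all_pairs = (x.unsqueeze(1) - y.unsqueeze(0)).norm(dim=-1)

# The row-matched result corresponds (up to eps) to the diagonal of the all-pairs matrix
diag = torch.diagonal(all_pairs)
```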

GOD I love this answer, thanks man.

Can this be done for a custom distance function like the Wasserstein distance?


Correction:

dist = dist - torch.diag(dist.diag)

should be

dist = dist - torch.diag(dist.diag())

Very helpful. Now, can this be done batch-wise? I mean, if a is a 3D tensor of shape batch x m x hidden_size, how should we apply dist2 to get back a tensor of shape batch x m x m?
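One way to do this is to insert the pairing axes one position later, so the batch dimension simply rides along in the broadcast. This is a sketch; `batched_dist2` is a hypothetical name, and it has the same memory cost as `dist2` (batch x m x m x hidden_size).

```python
import torch

def batched_dist2(a):
    # a: (batch, m, hidden). a[:, :, None] is (batch, m, 1, hidden) and
    # a[:, None] is (batch, 1, m, hidden), so the difference broadcasts to
    # (batch, m, m, hidden); the norm over the last dim gives (batch, m, m).
    return torch.norm(a[:, :, None] - a[:, None], dim=-1, p=2)

a = torch.randn(2, 4, 3)
d = batched_dist2(a)   # (2, 4, 4)
```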

For those looking at this post nowadays (2021), torch.cdist is available.
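A short sketch of `torch.cdist` computing the squared distances from the original question. By default (`compute_mode="use_mm_for_euclid_dist_if_necessary"`), cdist switches to the matrix-multiplication approach for p=2 once either input has more than 25 rows, which brings back the rounding behaviour of Approach B; passing `"donot_use_mm_for_euclid_dist"` asks it to compute the differences directly instead. It also accepts batched inputs such as the (3, 5, 2) and (3, 2, 2) shapes from the earlier reply. The tensor sizes here are arbitrary.

```python
import torch

x = torch.randn(100, 16)
y = torch.randn(80, 16)

# Squared Euclidean distances, avoiding the matrix-multiplication path
d_direct = torch.cdist(x, y, p=2.0, compute_mode="donot_use_mm_for_euclid_dist") ** 2

# Batched inputs: (B, N, d) and (B, M, d) -> (B, N, M)
xb = torch.randn(3, 5, 2)
yb = torch.randn(3, 2, 2)
d_batched = torch.cdist(xb, yb)
```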


From what it looks like, it only supports p-norm distances (p = 1, 2, 3, etc.). Does anyone have a typical speed gain for torch.cdist?
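No benchmark numbers here, but a crude timing harness makes it easy to measure on your own hardware. This is a sketch with hypothetical helper names (`time_fn`, `quadratic_expansion`) and arbitrary input sizes; for more careful measurements, `torch.utils.benchmark` is worth a look. Note that `torch.cdist` returns distances while the expansion returns squared distances, so this compares speed only, not outputs.

```python
import time
import torch

def time_fn(fn, *args, repeats=10):
    # Crude wall-clock timing; synchronise so queued CUDA kernels are counted.
    fn(*args)  # warm-up
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats

def quadratic_expansion(x, y):
    # Approach B: (N, 1) + (M,) broadcasts to (N, M)
    return (x ** 2).sum(1, keepdim=True) + (y ** 2).sum(1) - 2.0 * (x @ y.t())

x = torch.randn(2000, 64)
y = torch.randn(1500, 64)
print("cdist:", time_fn(torch.cdist, x, y))
print("expansion:", time_fn(quadratic_expansion, x, y))
```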

I think that scipy.stats.wasserstein_distance would be a good starting point for this. The source code mostly uses standard NumPy functionality for which I think there are compatible PyTorch functions. Not exactly sure how that would translate to the .view() approach of B, though. If generating the pairwise distance matrix is the main desired output, I have a working Numba implementation that is ~130x faster than using cdist(x, y, metric=scipy.stats.wasserstein_distance)
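A PyTorch-only sketch for one special case, separate from the Numba implementation mentioned above: if every row is treated as a 1-D empirical sample with equal weights, and all rows have the same length d, then the 1-D Wasserstein-1 distance between two rows is the mean absolute difference of their sorted values. Sorting each row once turns the whole pairwise matrix into an L1 `cdist`. `pairwise_wasserstein_1d` is a hypothetical name, and unequal weights or sample sizes are not handled.

```python
import torch

def pairwise_wasserstein_1d(x, y):
    # Assumes each row of x (N x d) and y (M x d) is an equally weighted
    # 1-D sample of the same size d. Then
    #   W1(u, v) = mean_k |sort(u)_k - sort(v)_k|,
    # so after sorting each row the matrix is an L1 cdist divided by d.
    xs, _ = torch.sort(x, dim=1)
    ys, _ = torch.sort(y, dim=1)
    return torch.cdist(xs, ys, p=1) / x.size(1)

x = torch.tensor([[3.0, 1.0, 2.0], [0.0, 0.0, 0.0]])
y = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]])
w = pairwise_wasserstein_1d(x, y)   # (2, 2)
```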

May I ask for the Numba solution you mentioned? I’m not familiar with Numba and so far have only managed to get a 2x boost.