Efficient Distance Matrix Computation


(James) #1

Hi All,

For the project I’m working on right now I need to compute distance matrices over large batches of data. I have two matrices X and Y, where X is n×d and Y is m×d. The distance matrix D is then n×m and contains the squared Euclidean distance between each row of X and each row of Y.

So far I’ve implemented this in a few different ways, but each has its issues, and I’m hoping someone more experienced with PyTorch might be able to help me get an implementation that matches my needs. I’ll go through each approach and its issues below.

Approach A - Direct Expanded

def expanded_pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    if y is not None:
        differences = x.unsqueeze(1) - y.unsqueeze(0)
    else:
        differences = x.unsqueeze(1) - x.unsqueeze(0)
    distances = torch.sum(differences * differences, -1)
    return distances

This approach adds extra dimensions to compute the difference between all combinations of rows at once. The intermediate difference tensor is n×m×d, so it requires a lot of memory and is slow.

Approach B - Quadratic Expansion

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist

This approach requires less memory and is faster than the above. However, it suffers from rounding errors, which in my setup lead to numerical instability. This seems to be a well-documented issue with this expansion, and existing libraries use thresholding to fall back to the direct computation when needed.
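To make the rounding issue concrete, here is a small repro of my own (the values are chosen purely for illustration, they are not from my data):

```python
import torch

# Hypothetical repro: two nearly identical points with large coordinates.
# The true squared distance is ~1e-8, but the quadratic expansion computes
# it as a difference of numbers around 2e6, where float32 only resolves
# steps of about 0.25 -- catastrophic cancellation.
x = torch.tensor([[1000.0, 1000.0],
                  [1000.0, 1000.0 + 1e-4]])

x_norm = (x ** 2).sum(1).view(-1, 1)
quad = x_norm + x_norm.view(1, -1) - 2.0 * torch.mm(x, x.t())

# The direct computation subtracts the coordinates first, so the tiny
# difference survives.
direct = ((x.unsqueeze(1) - x.unsqueeze(0)) ** 2).sum(-1)
```

Here `quad[0, 1]` comes out as 0.0 or ±0.25 instead of ~1e-8; a subsequent sqrt of a negative entry then produces nan.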

Approach C - Direct row-wise

def row_pairwise_distances(x, y=None, dist_mat=None):
    if y is None:
        y = x
    if dist_mat is None:
        dist_mat = torch.empty(x.size(0), y.size(0), dtype=x.dtype, device=x.device)

    for i, row in enumerate(x.split(1)):
        r_v = row.expand_as(y)
        sq_dist = torch.sum((r_v - y) ** 2, 1)
        dist_mat[i] = sq_dist.view(1, -1)
    return dist_mat

The above is numerically stable and has a lower memory footprint than A, but it is far slower than Approach B. I think it should be possible to modify C to be faster, but I don’t know enough about the autograd engine and PyTorch’s data storage to make the tweaks myself. Does anybody have any suggestions?
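One tweak I’ve been considering (a sketch of my own, not a tested solution): process x in chunks of rows rather than one row at a time, so each iteration still uses the numerically direct subtraction but amortizes the Python loop overhead. The `chunk_size` value here is an arbitrary assumption to tune for your memory budget.

```python
import torch

def chunked_pairwise_distances(x, y=None, chunk_size=256):
    """Direct squared Euclidean distances, computed chunk_size rows at a time.

    Peak intermediate memory is chunk_size * m * d instead of n * m * d,
    while the loop runs n / chunk_size times instead of n times.
    """
    if y is None:
        y = x
    out = torch.empty(x.size(0), y.size(0), dtype=x.dtype, device=x.device)
    for start in range(0, x.size(0), chunk_size):
        chunk = x[start:start + chunk_size]           # (c, d)
        diff = chunk.unsqueeze(1) - y.unsqueeze(0)    # (c, m, d) intermediate
        out[start:start + chunk_size] = (diff * diff).sum(-1)
    return out
```

This interpolates between Approach A (chunk_size = n) and Approach C (chunk_size = 1) while keeping the direct, stable computation.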


Is this implementation of calculating the L2 distance matrix efficient?
#2

Hello. I am working on a project that needs to calculate this distance too. First, I want to thank you for sharing such an elaborate set of approaches.
I initially did the calculation with Approach A and ran out of memory; the PyTorch documentation does not seem to offer a good solution to date.
I had not thought of calculating (a-b)^2 as a^2+b^2-2ab; I think this is the best way in terms of speed and memory when using ready-made PyTorch APIs.
The best way to deal with the numerical stability may be to write a cffi extension and use CUDA to calculate the distances directly.


(James) #3

Glad there is something helpful here for you! In the end I managed to get a suitable solution by adapting Approach B to clamp the distance matrix to be non-negative and ensure the diagonal is zero when needed:

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
            if y is not given then use 'y=x'.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Ensure diagonal is zero if x=y
    # if y is None:
    #     dist = dist - torch.diag(dist.diag())
    return torch.clamp(dist, 0.0, np.inf)

For my code this was enough to keep things stable during optimization. I think that in general you might be correct that we’d need a cffi extension.
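A quick sanity check of the clamped version (my own, repeated here in a self-contained form so it runs on its own): the result should be non-negative everywhere and agree with the direct expansion up to float rounding.

```python
import torch

def pairwise_distances(x, y=None):
    # Same quadratic expansion + clamp as above, restated so this
    # snippet is self-contained.
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        y = x
        y_norm = x_norm.view(1, -1)
    else:
        y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    return torch.clamp(dist, min=0.0)

torch.manual_seed(0)
x = torch.randn(8, 5)
d = pairwise_distances(x)
direct = ((x.unsqueeze(1) - x.unsqueeze(0)) ** 2).sum(-1)

assert (d >= 0).all()                        # clamp guarantees non-negativity
assert torch.allclose(d, direct, atol=1e-5)  # matches the direct computation
```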


(Ruth) #4

@AtheMathmo: For pytorch version 0.2.0_4 (I’m not sure about the latest pytorch version), I found that torch.clamp(dist, 0.0, np.inf) didn’t deal with the nan values that sometimes arise along the diagonal when running pairwise_distances(x, y=None). To mitigate this, I replaced that line with the following:

dist[dist != dist] = 0 # replace nan values with 0


(Fusheng Hao) #5

I expand both x (n × d) and y (m × d) as follows:

n = x.size(0)
m = y.size(0)
d = x.size(1)

x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)

Then calculate dist[i,j] = ||x[i,:]-y[j,:]||^2 in the following way:

dist = torch.pow(x - y, 2).sum(2) 

Both x and y are torch.Tensor. Is this exactly what you want? Or is there a more efficient way?
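For reference, a quick check (mine, not benchmarked) that this gives the same result as the broadcasting version in Approach A. Note that expand only creates views without copying, but the subtraction still materializes the full n × m × d tensor, so memory behaves the same:

```python
import torch

torch.manual_seed(0)
x = torch.randn(6, 4)
y = torch.randn(5, 4)
n, m, d = x.size(0), y.size(0), x.size(1)

# expand-based version from this post
xe = x.unsqueeze(1).expand(n, m, d)
ye = y.unsqueeze(0).expand(n, m, d)
dist_expand = torch.pow(xe - ye, 2).sum(2)

# Approach A relies on broadcasting to do the same expansion implicitly
dist_broadcast = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(-1)

assert torch.allclose(dist_expand, dist_broadcast)
```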


(Alessio Ansuini) #6

Thank you so much!
With your function, moving the matrices to CUDA, I was able to compute pairwise distances 210-220 times faster than scipy.spatial.distance.pdist!
I checked this on matrices of 1000 to 10000 samples of flattened CIFAR-10 images.
There is a small error with respect to the scipy function; I estimated the relative error (max / median, taken over each pairwise distance) to be around 1e-5, which is fine for me.


(Derek Kim) #7

I think your approach is essentially identical to the author’s first approach, so it is redundant.


(Hassan Muhammad) #8

This answer is MEGA AESTHETIC.
thanks!!

also shout out to @AtheMathmo for posing the same problem I had in terms that helped me understand my own situation


#9

Hi, thank you for your Approach B - Quadratic Expansion. Pretty impressive!
But may I ask whether there is a similar way to calculate the L1 distance (||x - y||_1) efficiently?


(Ehsan M Kermani) #10

Note that the methods above are not as numerically stable as torch.norm(input[:, None] - input, dim=2, p=2) or torch.pdist, which was recently implemented for faster computation of the distance matrix and returns the upper-triangular values.

Example:

a = torch.tensor([[ 1.4335, -1.0990, -0.8586],
        [ 2.1553,  2.7028, -0.8020],
        [ 1.0524,  0.1599, -0.0374]])

def dist(x):
    assert len(x.size()) == 2
    norm = (x ** 2).sum(1).view(-1, 1)
    dn = (norm + norm.view(1, -1)) - 2.0 * (x @ x.t())
    return dn.sqrt()

def dist2(x):
    return torch.norm(x[:, None] - x, dim=2, p=2)

then

dist(a)
tensor([[9.7656e-04, 3.8701e+00, 1.5507e+00],  # --> entry (0,0) MUST be zero
        [3.8701e+00, 0.0000e+00, 2.8752e+00],
        [1.5507e+00, 2.8752e+00, 0.0000e+00]])

dist2(a) 
tensor([[0.0000, 3.8701, 1.5507],
        [3.8701, 0.0000, 2.8752],
        [1.5507, 2.8752, 0.0000]])

torch.pdist(a) --> tensor([3.8701, 1.5507, 2.8752])