For the project I’m working on right now I need to compute distance matrices over large batches of data. I have two matrices X and Y, where X is Nxd and Y is Mxd. The distance matrix D is then NxM and contains the squared Euclidean distance between each row of X and each row of Y.
So far I’ve implemented this in a few different ways, but each has its own issues, and I’m hoping someone more experienced with PyTorch might be able to help me get an implementation that matches my needs. I’ll go through each approach and the related issues below.
Approach A - Direct Expanded
```python
import torch

def expanded_pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm
            between x[i,:] and y[j,:]
    If y is not given then use 'y=x', i.e.
        dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    if y is not None:
        differences = x.unsqueeze(1) - y.unsqueeze(0)
    else:
        differences = x.unsqueeze(1) - x.unsqueeze(0)
    distances = torch.sum(differences * differences, -1)
    return distances
```
This approach adds extra dimensions so that broadcasting computes the difference between all combinations of rows at once. The downside is that it materialises the full NxMxd intermediate tensor, which requires a lot of memory and is slow.
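To make the footprint concrete, here is a rough back-of-the-envelope calculation with hypothetical sizes (N = M = 10,000, d = 128, which are not my actual data sizes, just an illustration): the broadcasted difference tensor alone needs about 51 GB in float32, before the reduction even runs.

```python
n, m, d = 10_000, 10_000, 128   # hypothetical sizes for illustration
elements = n * m * d            # the (n, m, d) intermediate created by broadcasting
gigabytes = elements * 4 / 1e9  # 4 bytes per float32 element
print(f"{gigabytes:.1f} GB")    # 51.2 GB for the intermediate alone
```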
Approach B - Quadratic Expansion
```python
def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm
            between x[i,:] and y[j,:]
    If y is not given then use 'y=x', i.e.
        dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y ** 2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist
```
This approach requires much less memory and is faster than the above, since the only NxM-sized object is the output itself. However, it suffers from rounding errors (the expansion ||x||^2 + ||y||^2 - 2x·y can cancel catastrophically, and even come out slightly negative) and in my setup leads to numerical instability. This seems to be a well-documented issue with this approach, and existing libraries use thresholding to fall back to the direct computation when needed.
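As a minimal mitigation (not a full fix for the cancellation itself), the result can at least be clamped so no entry goes negative, which matters if a square root is taken afterwards. A sketch of the same quadratic expansion with clamping:

```python
import torch

def pairwise_distances_clamped(x, y):
    # Quadratic expansion: ||x||^2 + ||y||^2 - 2 x.y, as in Approach B.
    x_norm = (x ** 2).sum(1).view(-1, 1)
    y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    # Rounding can leave tiny negative entries; clamp them to zero
    # so a later sqrt does not produce NaNs.
    return torch.clamp(dist, min=0.0)
```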
Approach C - Direct row-wise
```python
def row_pairwise_distances(x, y=None, dist_mat=None):
    if y is None:
        y = x
    if dist_mat is None:
        # Note: we need the row counts here, not the full shapes
        # (the original torch.Tensor(x.size(), y.size()) was a bug).
        dist_mat = torch.empty(x.size(0), y.size(0),
                               dtype=x.dtype, device=x.device)
    for i, row in enumerate(x.split(1)):
        r_v = row.expand_as(y)
        sq_dist = torch.sum((r_v - y) ** 2, 1)
        dist_mat[i] = sq_dist.view(1, -1)
    return dist_mat
```
The above is numerically stable and has a much lower memory footprint than A (only one row of differences is live at a time), but the Python-level loop makes it far slower than approach B. I think it should be possible to modify C to be faster, but I don’t know enough about the autograd engine and PyTorch data storage to make the tweaks myself. Does anybody have any suggestions?
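One idea I’ve been toying with (just a sketch, and the chunk size is a guess I’d have to tune) is a middle ground between A and C: keep the direct, numerically stable computation, but vectorise over blocks of rows instead of single rows, so the broadcasted intermediate is only chunk_size x M x d at a time:

```python
import torch

def chunked_pairwise_distances(x, y=None, chunk_size=256):
    # Direct squared-distance computation as in Approach A, applied to
    # row-chunks of x so the (chunk, M, d) intermediate stays small.
    # chunk_size trades memory for speed and would need tuning.
    if y is None:
        y = x
    out = []
    for xc in x.split(chunk_size):
        diff = xc.unsqueeze(1) - y.unsqueeze(0)
        out.append((diff * diff).sum(-1))
    return torch.cat(out, 0)
```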