Hi All,

For the project I’m working on right now I need to compute distance matrices over large batches of data. I have two matrices X and Y, where X is nxd and Y is mxd. The distance matrix D is then nxm and contains the squared Euclidean distance between each row of X and each row of Y.
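To make the goal concrete, here's a tiny example of the matrix I'm after, using `torch.cdist` (squared) as a reference:

```
import torch

# Small illustrative inputs: X is n x d, Y is m x d.
X = torch.tensor([[0.0, 0.0], [1.0, 0.0]])              # n = 2, d = 2
Y = torch.tensor([[0.0, 0.0], [0.0, 2.0], [3.0, 4.0]])  # m = 3

# Reference result: D[i, j] = ||X[i] - Y[j]||^2.
D = torch.cdist(X, Y) ** 2
# e.g. D[0, 2] = 3^2 + 4^2 = 25
```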

So far I’ve implemented this in a few different ways, but each has its issues, and I’m hoping someone more experienced with PyTorch might be able to help me get an implementation that matches my needs. I’ll go through each approach and the related issues below.

### Approach A - Direct Expanded

```
import torch

def expanded_pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm
            between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2.
    If y is not given then use y = x.
    '''
    if y is None:
        y = x
    # Broadcast to a NxMxd difference tensor, then reduce over d.
    differences = x.unsqueeze(1) - y.unsqueeze(0)
    distances = torch.sum(differences * differences, -1)
    return distances
```

This approach adds extra dimensions so that the differences between all pairs of rows are computed in one broadcasted operation. The intermediate difference tensor is NxMxd, so this requires a lot of memory and is slow.
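To illustrate where the memory cost comes from, here's a small sketch of the intermediate tensor that approach A allocates:

```
import torch

# Sketch of the broadcasted intermediate in approach A.
x = torch.randn(100, 128)  # n x d
y = torch.randn(200, 128)  # m x d
diff = x.unsqueeze(1) - y.unsqueeze(0)
print(diff.shape)  # torch.Size([100, 200, 128])

# Scaling this up: at n = m = 10000 and d = 128, the float32 intermediate
# alone is 10000 * 10000 * 128 * 4 bytes, roughly 51 GB.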

### Approach B - Quadratic Expansion

```
import torch

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm
            between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2.
    If y is not given then use y = x.
    '''
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is not None:
        y_norm = (y ** 2).sum(1).view(1, -1)
    else:
        y = x
        y_norm = x_norm.view(1, -1)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y
    dist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist
```

This approach requires less memory and is faster than the above. However, it suffers from rounding errors (the subtraction can cancel catastrophically, even producing small negative "distances"), and in my setup this leads to numerical instability. This seems to be a well-documented issue with the quadratic expansion, and existing libraries use thresholding or fall back to the direct computation when needed.
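For reference, the simplest mitigation I've seen can be sketched as clamping negative values to zero (`pairwise_distances_clamped` is just an illustrative name; this removes the impossible negative outputs but doesn't fix the relative error on near-zero distances):

```
import torch

def pairwise_distances_clamped(x, y):
    # Quadratic-expansion distances with negative rounding artifacts
    # clamped to zero. A minimal mitigation sketch, not a full fix.
    x_norm = (x ** 2).sum(1).view(-1, 1)
    y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())
    return torch.clamp(dist, min=0.0)
```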

### Approach C - Direct row-wise

```
import torch

def row_pairwise_distances(x, y=None, dist_mat=None):
    if y is None:
        y = x
    if dist_mat is None:
        dist_mat = torch.empty(x.size(0), y.size(0),
                               dtype=x.dtype, device=x.device)
    # Compute one row of the distance matrix at a time.
    for i, row in enumerate(x.split(1)):
        r_v = row.expand_as(y)
        sq_dist = torch.sum((r_v - y) ** 2, 1)
        dist_mat[i] = sq_dist.view(1, -1)
    return dist_mat
```

The above is numerically stable and has a lower memory footprint than A, but it is far slower than approach B. I think it should be possible to modify C to be faster, but I don’t know enough about the autograd engine and PyTorch data storage to make the tweaks myself. Does anybody have any suggestions?