Batched Pairwise Distance

I have tensors X of shape BxNxD and Y of shape BxMxD.

I want to compute the pairwise distances for each element in the batch, i.e. I want a BxNxM tensor.

How do I do this?

There is some discussion on this topic here: https://github.com/pytorch/pytorch/issues/9406, but I can’t follow it: there are many implementation details and no actual solution is highlighted.

A naive approach would be to take the answer for non-batched pairwise distances from the thread "Efficient Distance Matrix Computation" and loop over the batch, i.e.

import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))


def pairwise_distances(x, y=None):
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)


out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)

How can I do this without looping over B?
Thanks

Hi,

my solution was:

def pairwise_distances(x, y):
    '''
    Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
    Input: x is a bxNxd matrix
           y is a bxMxd matrix
    Output: dist is a bxNxM matrix where dist[b,i,j] is the square norm between x[b,i,:] and y[b,j,:]
    i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
    '''
    x_norm = x.norm(dim=2)[:,:,None]
    y_t = y.permute(0,2,1).contiguous()
    y_norm = y.norm(dim=2)[:,None]

    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)

    return torch.clamp(dist, 0.0, np.inf)

but when I compared it with the “direct expansion” approach, the approximation error was so large that I didn’t use it (maybe there is a bug that I cannot see).
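
For anyone who wants a reference to compare against, here is a minimal sketch of what I understand by the "direct expansion" approach: broadcast the difference of every pair of points and sum the squares. The function name is made up for illustration, and the intermediate tensor is BxNxMxD, so this is only meant as a check on small inputs, not as a memory-efficient solution.

```python
import torch


def direct_pairwise_squared_distances(x, y):
    # x: (B, N, D), y: (B, M, D)
    # Broadcasting builds every pairwise difference explicitly: (B, N, M, D).
    diff = x[:, :, None, :] - y[:, None, :, :]
    # Sum squared differences over the feature axis -> (B, N, M).
    return (diff ** 2).sum(dim=-1)


torch.manual_seed(0)
x = torch.randn(4, 5, 3, dtype=torch.float64)
y = torch.randn(4, 7, 3, dtype=torch.float64)

ref = direct_pairwise_squared_distances(x, y)
print(ref.shape)
```

Any faster formulation should agree with this up to floating-point rounding.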


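The issue was in the norm terms: x.norm(dim=2) returns the Euclidean norm, whereas the expansion ||x||^2 + ||y||^2 - 2 x·y needs the squared norm. I’ve fixed that, made the reshapes explicit, and confirmed equivalence to the original non-batched pairwise_distances function: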

def batch_pairwise_squared_distances(x, y):
  '''
  Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
  Input: x is a bxNxd tensor
         y is a bxMxd tensor
  Output: dist is a bxNxM tensor where dist[b,i,j] is the squared distance between x[b,i,:] and y[b,j,:]
  i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
  '''
  x_norm = (x**2).sum(2).view(x.shape[0],x.shape[1],1)
  y_t = y.permute(0,2,1).contiguous()
  y_norm = (y**2).sum(2).view(y.shape[0],1,y.shape[1])
  dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)
  dist[dist != dist] = 0 # replace nan values with 0
  return torch.clamp(dist, 0.0, np.inf)
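
In case it is useful, here is a small self-contained check of this approach against torch.cdist, which accepts batched BxNxD and BxMxD inputs in reasonably recent PyTorch versions. Note that torch.cdist returns distances, not squared distances, so the sketch squares its output before comparing. The function body is a condensed restatement of the one above (keepdim and unsqueeze instead of view), not a different method, and the shapes are just the ones from the original question.

```python
import torch


def batch_pairwise_squared_distances(x, y):
    # x: (B, N, D), y: (B, M, D) -> (B, N, M)
    x_norm = (x ** 2).sum(dim=2, keepdim=True)   # (B, N, 1)
    y_norm = (y ** 2).sum(dim=2).unsqueeze(1)    # (B, 1, M)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))
    # Rounding can push tiny distances slightly below zero.
    return dist.clamp(min=0.0)


torch.manual_seed(0)
X = torch.randn(32, 128, 3, dtype=torch.float64)
Y = torch.randn(32, 256, 3, dtype=torch.float64)

batched = batch_pairwise_squared_distances(X, Y)
via_cdist = torch.cdist(X, Y, p=2) ** 2          # cdist is unsquared
max_abs_diff = (batched - via_cdist).abs().max().item()
print(batched.shape, max_abs_diff)
```

If you only need plain Euclidean distances, calling torch.cdist directly is probably the simplest way to avoid the loop over B.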

I put this in my loss function, and when I train my model with it, the weights become NaN after a few iterations. However, when I remove the torch.bmm(x, y_t) term, the model is able to train. Does anyone know what in torch.bmm() can cause this issue to occur?

At first I thought maybe I had to normalize my inputs x,y but that did not make a difference. I also tried using a much lower learning rate but it does not make a difference.
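
One possible cause, offered only as a guess since the loss code isn't shown: if the squared distances are passed through torch.sqrt somewhere in the loss, any pair of identical points gives a distance of exactly zero, and the derivative of sqrt at zero is infinite. That infinite gradient then propagates back and can turn the weights into NaN. Removing the bmm term would hide this, because the remaining sum of squared norms is rarely zero. The sketch below assumes such a sqrt exists; the helper name and the eps value are illustrative choices, not something from this thread.

```python
import torch


def batch_pairwise_squared_distances(x, y):
    x_norm = (x ** 2).sum(dim=2, keepdim=True)
    y_norm = (y ** 2).sum(dim=2).unsqueeze(1)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))
    return dist.clamp(min=0.0)


def grad_of_summed_distances(eps):
    # Distances of a point set to itself: the diagonal is exactly zero here
    # because the inputs are small integers and all arithmetic is exact.
    x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]], requires_grad=True)
    d2 = batch_pairwise_squared_distances(x, x)
    # Assumed loss: sum of (unsquared) distances, with an optional epsilon.
    torch.sqrt(d2 + eps).sum().backward()
    return x.grad


unsafe_grad = grad_of_summed_distances(eps=0.0)    # sqrt evaluated at 0
safe_grad = grad_of_summed_distances(eps=1e-12)    # sqrt kept away from 0
print(torch.isfinite(unsafe_grad).all().item())
print(torch.isfinite(safe_grad).all().item())
```

If this matches your setup, adding a small epsilon before the sqrt, or working with squared distances throughout, should keep the gradients finite. torch.autograd.set_detect_anomaly(True) can help locate the operation that first produces a non-finite gradient.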

Not sure if this helps, but have a look at the thread "Pairwise cosine distance".

This one too: "Batched Distance".