Batched Pairwise Distance

The issue was in the slicing on the norms. I’ve reconciled the reshapes and confirmed equivalence to the original non-batch pairwise_distances function:

import torch

def batch_pairwise_squared_distances(x, y):
  '''
  Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
  Input: x is a bxNxd matrix, y is a bxMxd matrix
  Output: dist is a bxNxM matrix where dist[b,i,j] is the squared norm between x[b,i,:] and y[b,j,:]
  i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
  '''
  x_norm = (x**2).sum(2).view(x.shape[0],x.shape[1],1)  # b x N x 1
  y_t = y.permute(0,2,1).contiguous()  # b x d x M
  y_norm = (y**2).sum(2).view(y.shape[0],1,y.shape[1])  # b x 1 x M
  dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)  # broadcasts to b x N x M
  dist[dist != dist] = 0 # replace nan values with 0
  return torch.clamp(dist, min=0.0) # rounding can leave tiny negative values
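
For anyone who wants to repeat the equivalence check, here is a minimal sketch. It is not the exact test I ran against pairwise_distances. It compares the expansion ||x||^2 + ||y||^2 - 2 x.y with a direct broadcasted difference. The tensor sizes, the seed, and the float64 dtype are arbitrary choices made for the illustration, and the function is restated in condensed form so the snippet runs on its own.

```python
import torch

def batch_pairwise_squared_distances(x, y):
    # Condensed restatement: ||x||^2 + ||y||^2 - 2 x.y, clamped at zero.
    x_norm = (x ** 2).sum(2).unsqueeze(2)   # b x N x 1
    y_norm = (y ** 2).sum(2).unsqueeze(1)   # b x 1 x M
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))
    return torch.clamp(dist, min=0.0)

torch.manual_seed(0)  # arbitrary seed, only for repeatability
x = torch.randn(4, 5, 3, dtype=torch.float64)  # b=4, N=5, d=3
y = torch.randn(4, 7, 3, dtype=torch.float64)  # b=4, M=7, d=3

batched = batch_pairwise_squared_distances(x, y)

# Reference: form every difference explicitly (b x N x M x d), then
# square and sum over the feature dimension. Simple but memory hungry.
diff = x.unsqueeze(2) - y.unsqueeze(1)
reference = (diff ** 2).sum(-1)

print(batched.shape)
print(torch.allclose(batched, reference))
```

The broadcasted reference allocates a b x N x M x d tensor, which is why the bmm version is preferable for large inputs. In float32 the two results can differ slightly more, so a looser tolerance in allclose may be needed there.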