# Batched Pairwise Distance

(Jonathan) #1

I have tensors X of shape `BxNxD` and Y of shape `BxMxD`.

I want to compute the pairwise distances for each element in the batch, i.e. I want a `BxNxM` tensor.
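Written out, and taking "distance" to mean the squared Euclidean distance (the code below never takes a square root), each entry of the desired output is:

$$
\mathrm{dist}[b,i,j] = \lVert X[b,i,:] - Y[b,j,:] \rVert_2^2 = \lVert X[b,i,:] \rVert_2^2 + \lVert Y[b,j,:] \rVert_2^2 - 2\, X[b,i,:] \cdot Y[b,j,:]
$$

for $0 \le b < B$, $0 \le i < N$, $0 \le j < M$. The expansion on the right is what allows the computation to be written as squared norms plus a matrix product.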

How do I do this?

There is some discussion on this topic here: https://github.com/pytorch/pytorch/issues/9406, but I don't understand it: it goes into many implementation details without highlighting an actual solution.

A naive approach would be to take the answer for non-batched pairwise distances from the thread "Efficient Distance Matrix Computation" and apply it to each batch element in a loop, i.e.

``````python
import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))

def pairwise_distances(x, y=None):
    # x: (N, D), y: (M, D); returns the (N, M) matrix of squared distances
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return dist

out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)  # torch.Size([32, 128, 256])
``````
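As a cross-check (a sketch, not part of the original question), the batched squared distances can also be computed with no Python loop over B by broadcasting `X[:, :, None, :] - Y[:, None, :, :]`, which costs O(B·N·M·D) memory, or with `torch.cdist`, which accepts batched `BxNxD` and `BxMxD` inputs and returns unsquared Euclidean distances. The helper name `batched_sq_dist_reference` and the small sizes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

# Small illustrative sizes (assumption); float64 keeps round-off negligible
B, N, M, D = 4, 5, 7, 3
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)

def batched_sq_dist_reference(x, y):
    """Hypothetical helper: squared distances by broadcasting.

    (B, N, 1, D) - (B, 1, M, D) -> (B, N, M, D), then sum over D.
    Memory grows as B*N*M*D, so this is only meant as a correctness check.
    """
    diff = x[:, :, None, :] - y[:, None, :, :]
    return (diff ** 2).sum(dim=-1)

ref = batched_sq_dist_reference(X, Y)
print(ref.shape)  # torch.Size([4, 5, 7])

# torch.cdist handles the batch dimension and returns Euclidean (not squared) distances
via_cdist = torch.cdist(X, Y) ** 2

# Per-batch loop using the norm expansion, as in the question
looped = torch.stack([
    (X[b] ** 2).sum(1)[:, None] + (Y[b] ** 2).sum(1)[None, :] - 2.0 * X[b] @ Y[b].T
    for b in range(B)
])

print(torch.allclose(ref, via_cdist), torch.allclose(ref, looped))
```

The broadcasting version is the most literal translation of the definition, which makes it a convenient reference when testing a faster `bmm`-based implementation.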

How can I do this without looping over B?
Thanks

(Pietro Astolfi) #2

Hi,

my solution was:

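``````python
import torch

def pairwise_distances(x, y=None):
    '''
    Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
    Input: x is a bxNxd matrix
           y is an optional bxMxd matrix (defaults to x)
    Output: dist is a bxNxM matrix where dist[b,i,j] is the squared norm between x[b,i,:] and y[b,j,:]
            i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
    '''
    if y is None:
        y = x
    # Squared norms (not plain norms) are needed for the expansion below
    x_norm = (x**2).sum(dim=2)[:, :, None]   # (b, N, 1)
    y_t = y.permute(0, 2, 1).contiguous()    # (b, d, M)
    y_norm = (y**2).sum(dim=2)[:, None]      # (b, 1, M)

    # Broadcasting (b, N, 1) + (b, 1, M) - (b, N, M) -> (b, N, M)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)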