I have tensors X of shape BxNxD and Y of shape BxMxD.
I want to compute the pairwise distances for each element in the batch, i.e. I want a BxNxM tensor whose entry [b, i, j] is the distance between X[b, i] and Y[b, j].
How do I do this?
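Written out, and assuming Euclidean distance, the requested output is a BxNxM tensor with the entries below, where each X_{b,i} and Y_{b,j} is a vector in R^D. The symbol \Delta is introduced here only for illustration. The second line is the expansion that the non-batched approach relies on; taking the square root gives the plain (non-squared) distance.

```latex
\Delta_{b,i,j} = \lVert X_{b,i} - Y_{b,j} \rVert_2^2
             = \lVert X_{b,i} \rVert_2^2 + \lVert Y_{b,j} \rVert_2^2 - 2\, X_{b,i}^{\top} Y_{b,j},
\qquad b = 1,\dots,B,\; i = 1,\dots,N,\; j = 1,\dots,M
```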
There is some discussion on this topic here: https://github.com/pytorch/pytorch/issues/9406, but I don't follow it: the thread covers many implementation details without highlighting an actual solution.
A naive approach is to reuse the answer for non-batched pairwise distances from Efficient Distance Matrix Computation and loop over the batch:
```python
import torch
import numpy as np

B = 32
N = 128
M = 256
D = 3

X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))

def pairwise_distances(x, y=None):
    # Squared Euclidean distances between the rows of x (N, D) and y (M, D)
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    # ||x||^2 + ||y||^2 - 2 x.y, shape (N, M)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    # Floating-point error can produce tiny negative values; clamp them to 0
    return torch.clamp(dist, 0.0, np.inf)

out = []
for b in range(B):
    out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)  # torch.Size([32, 128, 256])
```
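One way to drop the loop, sketched here under the assumption that the same squared-distance expansion is wanted, is to keep every step of the naive function but carry the batch dimension along: sum norms over the last axis with `keepdim`, and replace `torch.mm` with the batched `torch.bmm`. The function name `batched_pairwise_sq_distances` is hypothetical. Like the original, it returns squared distances.

```python
import torch

def batched_pairwise_sq_distances(x, y=None):
    """Hypothetical helper: squared Euclidean distances per batch element.

    x: (B, N, D), y: (B, M, D) -> (B, N, M).
    If y is None, distances are computed between the rows of x itself.
    """
    if y is None:
        y = x
    # ||x_i||^2 as a column per batch element: (B, N, 1)
    x_norm = (x ** 2).sum(dim=-1, keepdim=True)
    # ||y_j||^2 as a row per batch element: (B, 1, M)
    y_norm = (y ** 2).sum(dim=-1).unsqueeze(1)
    # Batched inner products: (B, N, D) @ (B, D, M) -> (B, N, M)
    cross = torch.bmm(x, y.transpose(1, 2))
    dist = x_norm + y_norm - 2.0 * cross
    # Rounding can make entries slightly negative; clamp them to zero
    return torch.clamp(dist, min=0.0)

B, N, M, D = 32, 128, 256, 3
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)
out = batched_pairwise_sq_distances(X, Y)
print(out.shape)  # torch.Size([32, 128, 256])
```

Broadcasting handles the (B, N, 1) + (B, 1, M) addition, so no Python-level loop over B remains; the cost is one batched matrix multiply.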
How can I do this without looping over B?
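Another option, assuming a PyTorch version whose `torch.cdist` accepts batched inputs of shape (B, N, D) and (B, M, D), is to let the library compute the BxNxM result directly. Note that `torch.cdist` returns Euclidean distances, not squared ones, so the result is squared at the end to match `pairwise_distances`. The broadcasting version is included only as a readable reference; it materializes a BxNxMxD intermediate and is memory hungry for large inputs.

```python
import torch

B, N, M, D = 32, 128, 256, 3
X = torch.randn(B, N, D, dtype=torch.float64)
Y = torch.randn(B, M, D, dtype=torch.float64)

# Batched Euclidean distances: (B, N, D), (B, M, D) -> (B, N, M)
dist = torch.cdist(X, Y, p=2.0)

# Reference via broadcasting: (B, N, 1, D) - (B, 1, M, D) -> (B, N, M, D),
# then the norm over the feature axis gives (B, N, M).
diff = X.unsqueeze(2) - Y.unsqueeze(1)
ref = diff.norm(dim=-1)

print(dist.shape)  # torch.Size([32, 128, 256])
print(torch.allclose(dist, ref, atol=1e-6))

# Square to match the squared distances returned by pairwise_distances
sq_dist = dist ** 2
```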