Pairwise distance between multidimensional tensors

I have two multidimensional tensors, t1 and t2, both of shape (b,c,h,w). I would like to calculate the pairwise distance between every pair in the batch, that is, the distance ||t1[i]-t2[j]|| for all i,j=1,2,...,b. Note that

  • every element t1[i] or t2[j] is in itself a tensor of shape (c,h,w)

  • I expect a result with shape (b,b)

  • I've tried using cdist after flattening, that is, d=torch.cdist(t1.flatten(1),t2.flatten(2)), but the resulting metric doesn't seem to give small values for similar tensors and large values for dissimilar ones.

I'd appreciate any help!

I guess broadcasting would work, but I’m unsure if you want to sum the difference or use any other norm:

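import torch

b, c, h, w = 2, 3, 4, 4

a = torch.randn(b, c, h, w)
# Broadcast (b, 1, c, h, w) against (1, b, c, h, w) to get every pairwise difference
res = a.unsqueeze(1) - a.unsqueeze(0)
# Reduce over the (c, h, w) dims; this sums signed differences, so swap in a norm if needed
res = res.sum(dim=[2, 3, 4])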
print(res)
# tensor([[ 0.0000, -3.6586],
#         [ 3.6586,  0.0000]])
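
Note that summing the signed differences, as above, gives an antisymmetric matrix rather than a distance. If the Euclidean norm is what you are after (an assumption on my part, since the question doesn't name a norm), a minimal sketch of the same broadcasting idea could look like the following. The names t1, t2, dist, and dist_cdist are only illustrative, and the comparison against torch.cdist assumes both inputs are flattened with flatten(1):

```python
import torch

torch.manual_seed(0)
b, c, h, w = 2, 3, 4, 4
t1 = torch.randn(b, c, h, w)
t2 = torch.randn(b, c, h, w)

# Broadcast to (b, b, c, h, w) so that diff[i, j] = t1[i] - t2[j]
diff = t1.unsqueeze(1) - t2.unsqueeze(0)
# Euclidean norm over the (c, h, w) dims gives a (b, b) matrix
dist = diff.pow(2).sum(dim=(2, 3, 4)).sqrt()

# Same quantity via cdist on (b, c*h*w) inputs, without the large intermediate
dist_cdist = torch.cdist(t1.flatten(1), t2.flatten(1))

print(dist.shape)
print(torch.allclose(dist, dist_cdist, atol=1e-4))
```

The broadcasting version materializes a (b, b, c, h, w) tensor, so for large batches the cdist call on flattened inputs is likely the more memory-friendly choice.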