I have two multidimensional tensors t1 and t2, each of shape (b,c,h,w). I would like to calculate the distance between every pair in the batch, that is, the distance between t1[i] and t2[j] for all i,j=1,2,...,b. Note that every element t1[i] or t2[j] is itself a tensor of shape (c,h,w).

I expect a result with shape (b,b).

I've tried using cdist after flattening, that is,
d=torch.cdist(t1.flatten(1),t2.flatten(1))
but the resulting metric doesn't seem to give small values for similar tensors and large values for dissimilar ones.
I'd appreciate any help!
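
For concreteness, here is a minimal sketch of the computation described above, assuming the intended metric is the plain Euclidean (L2) distance over all c*h*w entries. The helper name pairwise_l2 and the tensor sizes are made up for illustration. Both inputs are flattened from dimension 1 so that each becomes (b, c*h*w), and the result is checked against an explicit double loop.

```python
import torch

def pairwise_l2(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (b1, b2) matrix of L2 distances between t1[i] and t2[j]."""
    # Flatten everything except the batch dimension: (b, c, h, w) -> (b, c*h*w).
    x1 = t1.flatten(1)
    x2 = t2.flatten(1)
    # cdist compares every row of x1 with every row of x2.
    return torch.cdist(x1, x2, p=2)

torch.manual_seed(0)
b, c, h, w = 4, 3, 8, 8  # illustrative sizes, not from the question
t1 = torch.randn(b, c, h, w)
t2 = torch.randn(b, c, h, w)

d = pairwise_l2(t1, t2)

# Reference: the same distances computed one pair at a time.
ref = torch.stack([
    torch.stack([(t1[i] - t2[j]).pow(2).sum().sqrt() for j in range(b)])
    for i in range(b)
])

# Distance of each tensor to itself should be (numerically) zero.
self_d = pairwise_l2(t1, t1).diagonal()
```

If d matches ref, the flattening is doing what is intended. One thing that may explain values that look "large" even for similar tensors: the unnormalized L2 distance is a sum over c*h*w entries, so its scale grows with the tensor size. Dividing by the square root of c*h*w (giving a root-mean-square difference) is one possible way to make the values easier to compare, though whether that fits depends on what "similar" should mean here.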