I have two multidimensional tensors t1 and t2, both shaped (b,c,h,w). I would like to calculate the pairwise distance between every pair in the batch, that is, the distance ||t1[i]-t2[j]|| for all i,j=1,2,...,b. Note that:
- every element t1[i] or t2[j] is itself a tensor of shape (c,h,w)
- I expect a result with shape (b,b)
- I've tried using cdist after flattening, that is, d=torch.cdist(t1.flatten(1),t2.flatten(1)), but the resulting distances don't seem to be small for similar tensors and large for dissimilar ones.
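A minimal sketch for checking the flatten-then-cdist approach, assuming the Euclidean (L2) norm is the intended distance. It compares cdist against a direct broadcasting implementation of ||t1[i]-t2[j]||. It then builds hypothetical test data in which t2[i] is a slightly perturbed copy of t1[i] and checks that those "similar" pairs end up on the diagonal with the smallest distances. The helper names and test data are illustrative only. One detail worth knowing: when either input has more than 25 rows, cdist's default compute_mode may switch to a matrix-multiplication formulation that can be less precise for nearly identical vectors. Passing compute_mode="donot_use_mm_for_euclid_dist" avoids this.

```python
import torch

def pairwise_dist(t1, t2):
    """(b, b) matrix with D[i, j] = ||t1[i] - t2[j]||_2 (hypothetical helper)."""
    # Flatten each (c, h, w) sample into one vector while keeping dim 0 (batch),
    # so both inputs become (b, c*h*w). Note flatten(1) on BOTH tensors.
    return torch.cdist(
        t1.flatten(1),
        t2.flatten(1),
        p=2,
        # Avoid the matmul shortcut, which can lose precision for near-identical rows.
        compute_mode="donot_use_mm_for_euclid_dist",
    )

def pairwise_dist_broadcast(t1, t2):
    # Same quantity written directly from the definition:
    # (b, 1, c, h, w) - (1, b, c, h, w) -> (b, b, c, h, w)
    diff = t1.unsqueeze(1) - t2.unsqueeze(0)
    # L2 norm over the (c, h, w) dims -> (b, b)
    return torch.linalg.vector_norm(diff, dim=(2, 3, 4))

torch.manual_seed(0)
b, c, h, w = 5, 3, 16, 16
t1 = torch.randn(b, c, h, w)
# Hypothetical test data: t2[i] is a slightly perturbed copy of t1[i],
# so pair (i, i) is "similar" and every other pair is not.
t2 = t1 + 0.01 * torch.randn(b, c, h, w)

d = pairwise_dist(t1, t2)
print(d.shape)  # torch.Size([5, 5])
print(torch.allclose(d, pairwise_dist_broadcast(t1, t2), atol=1e-4))

# Similar pairs should sit on the diagonal with the smallest distances.
off_diag = d[~torch.eye(b).bool()]
print(d.diagonal().max().item() < off_diag.min().item())

# For entries of comparable scale, raw L2 distances grow roughly like
# sqrt(c*h*w); dividing by that gives a root-mean-square per-element
# difference (an optional rescaling, not something cdist does itself).
d_rms = d / (c * h * w) ** 0.5
```

If the diagonal check passes on data like this but not on real tensors, the flattened L2 distance is probably working as defined. The issue may then be that raw L2 distance does not match the notion of "similar" for that data.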
I'd appreciate any help!