Broadcasting distance function

I am trying to get the distance between each of the N slices of tensor1 (shape (N, x, y)) and tensor2 (shape (x, y)), returned as a tensor of size N.

For example:

a = torch.ones(2,2)
b = torch.zeros(4,2,2)

I expect distance(a, b) to be [4, 4, 4, 4], one value per slice of b (here, the sum of squared differences). However, torch.dist() does not let me choose a dim to reduce along. It reduces over all elements and returns a single scalar, the 2-norm of a - b, which here is tensor(4.).
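To see why torch.dist() gives only one number, here is a small check using the a and b from the example. It assumes nothing beyond the documented default p=2 of torch.dist: a is broadcast against b, and the norm is taken over all 16 elements at once, so the result is sqrt(16) = 4 rather than a per-slice value.

```python
import torch

a = torch.ones(2, 2)
b = torch.zeros(4, 2, 2)

# torch.dist broadcasts a against b, then reduces over *every* element:
# it returns the p-norm (p=2 by default) of (a - b), here sqrt(16 * 1**2) = 4.
total = torch.dist(a, b)
print(total)  # tensor(4.)

# The same scalar, computed explicitly from the broadcast difference.
manual_total = (a - b).pow(2).sum().sqrt()
print(manual_total)  # tensor(4.)
```

The 4 here matches the expected per-slice values only by coincidence; it is a single norm over the whole (4, 2, 2) difference.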

What is the best way to do this without a for loop? I want the operation to run in parallel on the GPU.

You could calculate the (squared Euclidean) distance manually via:

((a.expand(1, -1, -1) - b)**2).sum(dim=[1,2])
> tensor([4., 4., 4., 4.])
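A slightly more general sketch of the same idea, assuming a PyTorch release that provides torch.linalg.vector_norm: broadcasting already aligns trailing dimensions, so the explicit expand is optional, and reducing over dims 1 and 2 leaves one value per slice. If the true Euclidean distance is wanted rather than its square, vector_norm over the same dims gives it directly. Placing the tensors on a CUDA device (when one is available) runs the whole computation on the GPU with no Python loop.

```python
import torch

# Use the GPU when one is available; the same code runs on the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

a = torch.ones(2, 2, device=device)      # shape (x, y)
b = torch.zeros(4, 2, 2, device=device)  # shape (N, x, y)

# Broadcasting treats a (2, 2) as (1, 2, 2), so no explicit expand is needed.
# Summing over dims 1 and 2 leaves one squared distance per slice of b.
sq_dist = ((a - b) ** 2).sum(dim=(1, 2))
print(sq_dist)  # values: [4., 4., 4., 4.]

# Euclidean (L2) distance per slice: the square root of the values above.
l2_dist = torch.linalg.vector_norm(a - b, ord=2, dim=(1, 2))
print(l2_dist)  # values: [2., 2., 2., 2.]
```

Note that the expected [4, 4, 4, 4] from the question corresponds to the squared distance; the plain Euclidean distance of each slice is 2.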