Distance between two tensors in batch mode

I have two tensors in my forward function, each of size torch.Size([64, 256]), where 64 is the batch size. I would like to compute the L2 distance between the two tensors/vectors.

So, I have:


I am able to do something like:

dist = torch.abs(b - a)  # yields shape (64, 256), which is the expected output

However, this seems too naive, and I was wondering whether there is a better distance measure I could use here that still returns a tensor of shape (64, 256), since I feed the result into a linear layer.

I tried using PairwiseDistance, but it gives me back a tensor of size 256 instead of (64, 256).

Any pointers would be great.

I don’t fully understand your question.

Some explanation of the L2 norm:
The L2 norm reduces a multi-dimensional vector to a single number. For example, a position in physical space is a vector with 3 components, yet the L2 distance between two positions is a single number. It is the distance we use in everyday life: when you measure how far apart two objects in your room are, each object has 3 coordinates, but the distance is one value.
The same holds in any number of dimensions: for a vector with 256 components, the L2 norm is still a single number. You could also treat the 256 entries individually, i.e. as 256 separate 1-d vectors; then the output would have 256 entries. In the 1-d case the L2 norm is simply the absolute value.
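To make the room example concrete, here is a small sketch in plain Python (the two points are made up for illustration). It shows that the L2 distance between two 3-d points is a single number, and that in 1-d it reduces to the absolute difference.

```python
import math

# Two hypothetical points in a room, each with 3 coordinates (x, y, z).
p = (0.0, 0.0, 0.0)
q = (3.0, 4.0, 12.0)

# L2 distance: square root of the sum of squared coordinate differences.
# The result is one number, not a 3-component vector.
d3 = math.sqrt(sum((qi - pi) ** 2 for pi, qi in zip(p, q)))
print(d3)  # 13.0

# math.dist (Python 3.8+) computes the same Euclidean distance.
assert math.isclose(d3, math.dist(p, q))

# In one dimension the L2 distance is just the absolute difference.
a, b = 2.5, -1.0
assert math.isclose(math.dist((a,), (b,)), abs(a - b))
```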

Maybe you could specify: Do you want to know how to calculate the L2 distance? Or are you looking for another measure?

This should normally return a tensor of shape [64], or [64, 1] if keepdim=True:

import torch
import torch.nn.functional as F

t1 = torch.randn(64, 256)
t2 = torch.randn(64, 256)

t_dist = F.pairwise_distance(t1, t2)  # shape [64]: one distance per batch element
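As a cross-check of what F.pairwise_distance computes, the sketch below compares it with the norm of the difference taken over the feature dimension, and shows the effect of keepdim. Note that pairwise_distance adds a small eps (default 1e-6) to the difference, so the comparison uses a tolerance instead of exact equality.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
t1 = torch.randn(64, 256)
t2 = torch.randn(64, 256)

# One L2 distance per batch element: shape [64].
d = F.pairwise_distance(t1, t2)

# The same quantity by hand: L2 norm of the difference over dim 1 (features).
d_manual = torch.linalg.vector_norm(t1 - t2, ord=2, dim=1)

# keepdim=True keeps the reduced dimension: shape [64, 1].
d_keep = F.pairwise_distance(t1, t2, keepdim=True)

print(d.shape, d_keep.shape)  # torch.Size([64]) torch.Size([64, 1])
# Small tolerance because pairwise_distance adds eps to the difference.
assert torch.allclose(d, d_manual, atol=1e-4)
```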


Hello @Peter1998, first of all, thank you for the explanation; it certainly helped. What I am really trying to find out is what this line does:


Would you say this is the L2 norm, since it is the absolute value? I am not sure how this translates to a distance measure.

Apologies @Caruso: that is correct.

For 1-d coordinates, the absolute value is the ordinary L2 distance. If you treat your second dimension as 256 separate 1-d coordinates, your code computes the 256 corresponding distances for each sample. However, that is not the L2 norm of a 256-d vector.
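A minimal sketch of the distinction, using random tensors in place of the actual a and b from the forward function: torch.abs(b - a) keeps one 1-d distance per coordinate (shape [64, 256]), whereas the L2 distance between the two 256-d vectors of each sample combines them into one number (shape [64]).

```python
import torch

torch.manual_seed(0)
a = torch.randn(64, 256)  # stand-in for the first input tensor
b = torch.randn(64, 256)  # stand-in for the second input tensor

# Per-coordinate distances: each of the 256 entries is treated as its own
# 1-d coordinate, so the absolute difference is its 1-d L2 distance.
per_coord = torch.abs(b - a)  # shape [64, 256]

# L2 distance between the 256-d vectors of each sample:
# square root of the sum of the squared per-coordinate distances.
per_sample = per_coord.pow(2).sum(dim=1).sqrt()  # shape [64]

assert per_coord.shape == (64, 256)
assert per_sample.shape == (64,)
# Agrees with the built-in vector norm over the feature dimension.
assert torch.allclose(per_sample, torch.linalg.vector_norm(b - a, dim=1), atol=1e-5)
```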