I need to calculate the L2 distance between all pairs of elements in a mini-batch. I found the following code on the web. My question is: is this implementation efficient? Is there an API for this operation?
```python
import torch

def distanceMatrix(mtx1, mtx2):
    """
    mtx1 is a tensor with shape (n, d)
    mtx2 is a tensor with shape (n, d)
    Returns an n x n distance matrix dist, where
    dist[i, j] is the L2 distance between mtx1[i] and mtx2[j].
    """
    n = mtx1.size(0)
    # mmtx1[i, j] = mtx1[i], shape (n, n, d)
    mmtx1 = torch.stack([mtx1] * n).transpose(0, 1)
    # mmtx2[i, j] = mtx2[j], shape (n, n, d)
    mmtx2 = torch.stack([mtx2] * n)
    # Sum squared differences over the feature dimension, then take the root
    dist = torch.sum((mmtx1 - mmtx2) ** 2, 2).sqrt()
    return dist
```
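For comparison, here is a minimal sketch of two ways to get the same matrix without copying the inputs n times with `torch.stack`. It assumes PyTorch 1.1 or later, where the built-in `torch.cdist(x1, x2, p=2.0)` computes all pairwise distances directly; the helper names `pairwise_l2_broadcast` and `pairwise_l2_matmul` are hypothetical and only used here to check the approaches against `torch.cdist`.

```python
import torch


def pairwise_l2_broadcast(x, y):
    """Hypothetical helper: pairwise L2 distances via broadcasting, x (n, d), y (m, d) -> (n, m)."""
    # x[:, None, :] is (n, 1, d) and y[None, :, :] is (1, m, d); the subtraction
    # broadcasts to (n, m, d) without materializing stacked copies of the inputs first.
    diff = x[:, None, :] - y[None, :, :]
    return diff.pow(2).sum(dim=-1).sqrt()


def pairwise_l2_matmul(x, y):
    """Hypothetical helper: uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, avoiding any (n, m, d) tensor."""
    x_sq = x.pow(2).sum(dim=1, keepdim=True)      # (n, 1)
    y_sq = y.pow(2).sum(dim=1, keepdim=True).t()  # (1, m)
    sq = x_sq + y_sq - 2.0 * (x @ y.t())          # (n, m)
    # Floating-point rounding can make near-zero entries slightly negative; clamp before sqrt.
    return sq.clamp(min=0.0).sqrt()


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(5, 3, dtype=torch.float64)
    b = torch.randn(4, 3, dtype=torch.float64)
    d_ref = torch.cdist(a, b, p=2.0)  # built-in pairwise distance
    print(torch.allclose(pairwise_l2_broadcast(a, b), d_ref))
    print(torch.allclose(pairwise_l2_matmul(a, b), d_ref))
```

The broadcasting version still builds an (n, m, d) intermediate, so its memory grows with d; the matmul version only builds (n, m) matrices and is usually faster for large batches, at the cost of some numerical cancellation. If you backpropagate through distances that can be exactly zero (for example the diagonal of the distance matrix of a batch with itself), `sqrt` at zero yields NaN gradients; adding a small epsilon inside the square root is a common workaround.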
Thanks~