I need to calculate the L2 distance between all pairs of elements in a mini-batch. I found the following code on the web. Is this implementation efficient? Is there an API that performs this operation directly?
    import torch

    def distanceMatrix(mtx1, mtx2):
        """
        mtx1 is an autograd.Variable with shape (n, d)
        mtx2 is an autograd.Variable with shape (n, d)
        return an n x n distance matrix dist
        dist[i, j] is the squared L2 distance between mtx2[i] and mtx1[j]
        (no square root is taken)
        """
        m = mtx1.size(0)
        # (m, n, d): entry [i, j] holds mtx1[j]
        mmtx1 = torch.stack([mtx1] * m)
        # (n, m, d) after the transpose: entry [i, j] holds mtx2[i]
        mmtx2 = torch.stack([mtx2] * m).transpose(0, 1)
        # sum the squared differences over the feature dimension
        dist = torch.sum((mmtx1 - mmtx2) ** 2, 2).squeeze()
        return dist
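For comparison, here is a minimal sketch of the same computation written with broadcasting instead of `torch.stack`. The function name `squared_distance_matrix` and the example shapes are my own, for illustration only. It assumes a PyTorch version that provides `torch.cdist`, which is used here only as a cross-check. Note that `torch.cdist` returns the plain (non-squared) L2 distance.

```python
import torch

def squared_distance_matrix(a, b):
    """Return an (n, m) matrix whose [i, j] entry is ||a[i] - b[j]||^2."""
    # a: (n, d) -> (n, 1, d); b: (m, d) -> (1, m, d).
    # Broadcasting the subtraction yields an (n, m, d) tensor of differences.
    diff = a.unsqueeze(1) - b.unsqueeze(0)
    # Sum the squared differences over the feature dimension.
    return (diff ** 2).sum(dim=2)

# Hypothetical example inputs.
a = torch.randn(4, 3)
b = torch.randn(5, 3)
sq = squared_distance_matrix(a, b)

# Cross-check: torch.cdist gives the L2 distance, so square it to compare.
ref = torch.cdist(a, b, p=2) ** 2
```

This avoids building the stacked copies explicitly, but the intermediate difference tensor still has shape (n, m, d), so memory use grows with n * m * d either way.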