@Ismail_Elezi As @apaszke said, you can compute the similarity matrix for the L2 distance using only matrix operations.
Here is an implementation of your similarity_matrix that is built from matrix operations alone. It can run on the GPU and should be significantly faster than your previous implementation.
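Writing the rows of `mat` as $x_i$ and applying the expansion $(x - y)^2 = x^2 - 2xy + y^2$ row by row, every squared distance can be read off the Gram matrix $G = XX^\top$, which is what `torch.mm(mat, mat.t())` computes:

$$
D_{ij} = \lVert x_i - x_j \rVert^2 = \lVert x_i \rVert^2 - 2\,x_i^\top x_j + \lVert x_j \rVert^2 = G_{ii} - 2\,G_{ij} + G_{jj}
$$

The diagonal of $G$ holds the squared norms, so broadcasting it along rows and along columns and subtracting $2G$ gives all squared distances at once; the function then takes the elementwise square root.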
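```python
import torch

# (x - y)^2 = x^2 - 2*x*y + y^2
def similarity_matrix(mat):
    # get the product x * y
    # here, y = x.t()
    r = torch.mm(mat, mat.t())
    # get the diagonal elements (the squared norms of the rows)
    diag = r.diag().unsqueeze(0)
    diag = diag.expand_as(r)
    # compute the squared distance matrix
    D = diag + diag.t() - 2*r
    # rounding can leave tiny negative entries; clamp them so sqrt does not return NaN
    return D.clamp(min=0).sqrt()
```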
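The same expansion also works for distances between two different sets of points. Here is a minimal sketch, assuming `x` is n×d and `y` is m×d; the name `pairwise_distances` is hypothetical, it uses broadcasting instead of `expand_as`, and `torch.cdist` (available in PyTorch 1.1 and later) appears only as a sanity check:

```python
import torch

def pairwise_distances(x, y):
    """Euclidean distances between the rows of x (n x d) and the rows of y (m x d)."""
    # squared norms: x_sq has shape n x 1, y_sq has shape 1 x m
    x_sq = (x * x).sum(dim=1, keepdim=True)
    y_sq = (y * y).sum(dim=1).unsqueeze(0)
    # ||x_i||^2 - 2 x_i . y_j + ||y_j||^2, broadcast to n x m
    d2 = x_sq - 2 * (x @ y.t()) + y_sq
    # clamp tiny negative values caused by rounding before the square root
    return d2.clamp(min=0).sqrt()

torch.manual_seed(0)
x = torch.randn(5, 3)
y = torch.randn(4, 3)
d = pairwise_distances(x, y)
print(d.shape)  # torch.Size([5, 4])
print(torch.allclose(d, torch.cdist(x, y), atol=1e-4))
```

Passing the same tensor for `x` and `y` gives the square matrix that similarity_matrix returns.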
If you are not backpropagating through y, there is no need to wrap everything in Variables; just wrap the final result.
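The advice about Variables applies to PyTorch versions before 0.4. From 0.4 on, `Variable` was merged into `Tensor`, so the function works on plain tensors directly, and when no gradient is needed the call can run under `torch.no_grad()` so autograd records nothing. A minimal sketch, repeating the function (with broadcasting in place of `expand_as`) so the snippet is self-contained:

```python
import torch

def similarity_matrix(mat):
    # same computation as above, on plain tensors
    r = mat @ mat.t()
    diag = r.diag().unsqueeze(0)  # shape 1 x n, broadcasts against r
    D = diag + diag.t() - 2 * r
    return D.clamp(min=0).sqrt()

mat = torch.randn(6, 4, requires_grad=True)

# gradients tracked: the result is part of the autograd graph
D = similarity_matrix(mat)
print(D.requires_grad)  # True

# no gradients needed: nothing is recorded for backprop
with torch.no_grad():
    D_eval = similarity_matrix(mat)
print(D_eval.requires_grad)  # False
```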