Build your own loss function in PyTorch

@Ismail_Elezi As @apaszke said, you can compute the similarity matrix for the L2 distance using only matrix operations.
Here is such an implementation of your similarity_matrix. It can run on the GPU and should be significantly faster than your previous implementation.

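import torch

# ||x - y||^2 = ||x||^2 - 2 * <x, y> + ||y||^2
def similarity_matrix(mat):
    # Gram matrix of pairwise dot products: r[i, j] = <mat[i], mat[j]>
    r = torch.mm(mat, mat.t())
    # the diagonal of r holds the squared norms ||mat[i]||^2
    diag = r.diag().unsqueeze(0)
    diag = diag.expand_as(r)
    # squared distances: ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j>
    D = diag + diag.t() - 2*r
    # round-off can leave tiny negative entries, which would make sqrt return NaN
    return D.clamp(min=0).sqrt()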
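
As a quick sanity check, here is a minimal sketch that compares the matrix-only approach against `torch.cdist`, assuming a PyTorch version that provides it. The name `pairwise_l2`, the random input, and the `clamp` guard against round-off are illustrative additions, not part of the code above.

```python
import torch

def pairwise_l2(mat):
    # Gram matrix: r[i, j] = <x_i, x_j>
    r = torch.mm(mat, mat.t())
    # squared norms ||x_i||^2 sit on the diagonal
    sq = r.diag()
    # ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>, using broadcasting
    d2 = sq.unsqueeze(1) + sq.unsqueeze(0) - 2 * r
    # guard against tiny negative values caused by round-off
    return d2.clamp(min=0).sqrt()

torch.manual_seed(0)
x = torch.randn(5, 3, dtype=torch.float64)  # 5 points in 3 dimensions
d = pairwise_l2(x)
ref = torch.cdist(x, x)  # reference pairwise L2 distances
max_err = (d - ref).abs().max().item()
print(max_err)
```

The two results should agree up to floating-point error, and the matrix is symmetric with a zero diagonal.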

If you are not backpropagating through y, there is no need to wrap everything in Variables; just wrap the final result.
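
In PyTorch 0.4 and later, Variable and Tensor are merged, so there is nothing to wrap. A rough modern counterpart to this advice, sketched here as an assumption about how you would want to apply it, is to compute the distances under `torch.no_grad()` so no graph is recorded. The tensor `x` and its shape are made up for illustration.

```python
import torch

x = torch.randn(4, 2, requires_grad=True)  # hypothetical input

with torch.no_grad():
    # same matrix-only computation, but no autograd graph is built
    r = torch.mm(x, x.t())
    sq = r.diag()
    dist = (sq.unsqueeze(1) + sq.unsqueeze(0) - 2 * r).clamp(min=0).sqrt()

print(dist.requires_grad)
```

The result is detached from `x`, so it can be used as a constant in a loss without gradients flowing back through the distance computation.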
