Thanks, that makes sense, and now I have found the error.
@apaszke and @melgor
@fmassa's solution is not numerically stable on the diagonal. The error happens in:
D = diag + diag.t() - 2*r
The diagonal of D becomes zero here (which is correct), but for some reason the gradient becomes NaN after the following line:
D = D.sqrt()
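The failure can be reproduced in isolation with a sketch like the one below, assuming a recent PyTorch release. `x` is a hypothetical random input, and `r`/`diag` mirror the names above. It suggests the diagonal does not even need to go negative. The derivative of sqrt(t) is 1/(2*sqrt(t)), which is infinite at t = 0. During backpropagation that +inf (arriving through the `diag` terms) is likely added to the -inf arriving through the `- 2*r` term, and inf - inf = NaN.

```python
import torch

torch.manual_seed(0)

# Hypothetical input: 4 points in 3 dimensions.
x = torch.randn(4, 3, requires_grad=True)

r = x @ x.t()                                # Gram matrix, r[i, j] = <x_i, x_j>
diag = r.diag().unsqueeze(0).expand_as(r)    # diag[i, j] = r[j, j]
D = diag + diag.t() - 2 * r                  # squared distances; D[i, i] is exactly 0
D = D.sqrt()                                 # forward pass is fine: sqrt(0) == 0

D.sum().backward()
print(D.diag())                              # all zeros
print(bool(torch.isnan(x.grad).any()))       # True: the sqrt gradient is inf at 0
```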
Could the NaN mean that the diagonal entries are slightly smaller than 0, so that taking the square root produces NaNs? A cheap fix (which seems to work, for now) is to change the line that computes D to:
D = diag + diag.t() - 2*r + 1e-7
However, I am not sure whether this breaks anything else: the loss is decreasing, but I can't tell whether all the computations are correct.
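A possible alternative to adding a constant is to clamp the squared distances from below before taking the square root. Entries already above the threshold are left unchanged, and slightly negative values from floating-point cancellation are caught as well (a `+ 1e-7` offset would still produce NaN for anything below -1e-7). Since `clamp` passes no gradient to entries below the threshold, the diagonal receives a zero gradient instead of NaN. The helper name `pairwise_distances` and the threshold `eps=1e-12` are illustrative choices, not values from this thread.

```python
import torch


def pairwise_distances(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Euclidean distances between the rows of x (hypothetical helper).

    Squared distances are clamped to at least `eps`, so sqrt never sees 0
    (infinite gradient) or a small negative number (NaN in the forward pass).
    """
    r = x @ x.t()
    sq_norms = r.diag().unsqueeze(0)          # shape (1, n)
    d2 = sq_norms + sq_norms.t() - 2 * r      # broadcasting replaces expand_as
    return d2.clamp(min=eps).sqrt()


torch.manual_seed(0)
x = torch.randn(5, 3, requires_grad=True)
D = pairwise_distances(x)
D.sum().backward()

print(bool(torch.isfinite(x.grad).all()))                    # True
print(torch.allclose(D, torch.cdist(x, x), atol=1e-4))       # True
```

With the `+ 1e-7` offset instead, the diagonal becomes sqrt(1e-7) ≈ 3.2e-4 rather than 0 and every off-diagonal distance grows very slightly; whether that matters depends on the loss.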
On a side note, when I try to normalize the X_similarity matrix, this doesn't seem to work:
X_similarity = (X_similarity - X_similarity.mean())/X_similarity.std()
When I tried this in an experiment with plain tensors, it worked, but here, where X_similarity is a Variable, it does not.
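For reference, a minimal sketch assuming PyTorch 0.4 or later, where `Variable` has been merged into `Tensor`. There the same standardization works on a tensor that requires grad, and gradients flow through both `mean()` and `std()`. The random `X_similarity` below is a stand-in, not the matrix from this thread. In older releases, `mean()` and `std()` on a Variable returned a one-element Variable rather than a Python number, and releases before 0.2 had no broadcasting, which may explain the failure; that cause is an assumption, since no error message is shown.

```python
import torch

torch.manual_seed(0)

# Stand-in for the similarity matrix; any tensor that requires grad behaves the same.
x = torch.randn(6, 4, requires_grad=True)
X_similarity = x @ x.t()

# Standardize over all entries. mean() and std() return 0-dim tensors that
# broadcast against the matrix; std() applies Bessel's correction by default.
X_similarity = (X_similarity - X_similarity.mean()) / X_similarity.std()

X_similarity[0, 1].backward()      # any scalar loss; gradients reach x
print(X_similarity.mean().item(), X_similarity.std().item())   # approx. 0 and 1
```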