Using the Pearson Correlation Coefficient as a loss doesn't work well

Here is a Minimal Reproducible Example: Google Colab.

I’m trying to set up a custom loss, the Pearson Correlation Coefficient (PCC) loss, i.e. 1 − PCC.
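Written out per sample, with x the flattened prediction, y the flattened target and n the number of elements, this is the standard definition of Pearson's r and the resulting loss:

$$
r(x, y) = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_{i=1}^{n} (x_i - \bar{x})^2}\,\sqrt{\sum_{i=1}^{n} (y_i - \bar{y})^2}},
\qquad
\mathcal{L}(x, y) = 1 - r(x, y) \in [0, 2]
$$

The loss is 0 for perfect positive correlation and 2 for perfect negative correlation. The class below computes this per sample (with an extra 1e-8 in the denominator) and sums it over the batch.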

I wrote this class, expecting it to work out of the box:

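import torch
import torch.nn as nn


class PearsonLoss(nn.Module):
    def __init__(self):
        super().__init__()

    # `self` is required: nn.Module calls forward(self, x, y)
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Flatten each sample to a vector of shape (batch, -1);
        # reshape also works on non-contiguous tensors, unlike view
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)

        # Mean centering (per sample)
        x_mean = torch.mean(x, dim=1, keepdim=True)
        y_mean = torch.mean(y, dim=1, keepdim=True)
        x = x - x_mean
        y = y - y_mean

        # Numerator: unnormalized covariance per sample
        cov = torch.sum(x * y, dim=1)

        # Denominator: unnormalized standard deviations per sample
        # (the 1/n factors cancel in the ratio, so they are omitted)
        x_std = torch.sqrt(torch.sum(x**2, dim=1))
        y_std = torch.sqrt(torch.sum(y**2, dim=1))

        # Pearson Correlation Coefficient, one value per sample
        pcc = cov / ((x_std * y_std) + 1e-8)

        # Pearson **Loss**
        loss = 1 - pcc

        # Sum loss across the batch
        return loss.sum()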

But it has a lot of trouble converging (see my minimal reproducible example, which tries to map an image to exactly the same image with a very small UNet-like architecture).

Obviously this would work really well with an MSE loss, for example, so I’m wondering: did I implement something wrong in my Pearson Correlation Coefficient loss?

I need it for a “bigger project” that doesn’t fit in an MRE, and there too nothing converges. Since my network cannot even learn the identity function, I suspect there is indeed a problem with my class, but I can’t find it.

Any help is appreciated!

EDIT: sometimes it does converge, but most of the time it looks like it “gets stuck” in a local minimum.

Have you tried torch.corrcoef()?
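A minimal sketch of that suggestion, assuming a PyTorch version that ships torch.corrcoef (1.10 or later). torch.corrcoef takes a 2-D tensor whose rows are variables and whose columns are observations, so each sample's prediction and target are flattened and stacked, and the off-diagonal entry [0, 1] is Pearson's r. The name corrcoef_loss is hypothetical, and the per-sample Python loop favors clarity over speed.

```python
import torch


def corrcoef_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: per-sample 1 - Pearson r via torch.corrcoef, averaged over the batch
    losses = []
    for p, t in zip(pred, target):
        # Row 0 = flattened prediction, row 1 = flattened target
        r = torch.corrcoef(torch.stack([p.flatten(), t.flatten()]))[0, 1]
        losses.append(1 - r)
    return torch.stack(losses).mean()


torch.manual_seed(0)
pred = torch.randn(4, 1, 8, 8, requires_grad=True)
target = torch.randn(4, 1, 8, 8)
loss = corrcoef_loss(pred, target)
loss.backward()  # torch.corrcoef is differentiable, so gradients reach pred
```

Averaging over the batch (instead of summing, as PearsonLoss does) only rescales the gradient by the batch size. Comparing this against the hand-written class on the same inputs is a quick way to rule out an implementation bug.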

torch.sqrt(x) is not numerically stable near 0: if x is close to 0, the gradient becomes very large, and at exactly 0 it is infinite (which turns into NaN through the chain rule), because for y = sqrt(x) the gradient is dy/dx = 1/(2*sqrt(x)).
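A small sketch of how this applies to the class above: there the 1e-8 is added after the sqrt, so it protects the division but not the sqrt's gradient. If a sample's centered prediction is all zeros (e.g. a network that outputs a constant image), the gradient through x_std is NaN. pearson_loss_stable is a hypothetical variant that moves eps inside the sqrt; the value of eps is an assumption.

```python
import torch

# d/ds sqrt(s) = 1 / (2 * sqrt(s)) is infinite at s = 0; the chain rule then
# multiplies inf by d(sum(v**2))/dv = 2 * v = 0, which gives NaN.
v = torch.zeros(5, requires_grad=True)
torch.sqrt(torch.sum(v**2)).backward()
print(v.grad)  # every entry is NaN


def pearson_loss_stable(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Hypothetical variant of PearsonLoss.forward with eps *inside* the sqrt
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    x = x - x.mean(dim=1, keepdim=True)
    y = y - y.mean(dim=1, keepdim=True)
    cov = torch.sum(x * y, dim=1)
    # sqrt(s + eps) has a finite gradient even when s == 0
    x_norm = torch.sqrt(torch.sum(x**2, dim=1) + eps)
    y_norm = torch.sqrt(torch.sum(y**2, dim=1) + eps)
    pcc = cov / (x_norm * y_norm)
    return (1 - pcc).sum()


# A constant prediction no longer produces NaN gradients
pred = torch.zeros(2, 1, 4, 4, requires_grad=True)
target = torch.randn(2, 1, 4, 4)
pearson_loss_stable(pred, target).backward()
```

Moving eps inside only keeps the gradient finite; near a constant output it can still be large (it scales roughly like 1/sqrt(eps)), so this guards against NaNs rather than fixing slow convergence by itself.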

Also, I don’t understand why you want to use a Pearson correlation loss to optimize “an image”. Something like MAE or MSE should work fine for images.
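One way to see this point, as an illustrative sketch rather than a diagnosis of the original MRE: Pearson's r is invariant to shifting the prediction and to rescaling it by a positive factor, so an output with the wrong brightness and contrast can already reach a Pearson loss of about 0, while MSE still penalizes it. pearson_r is a hypothetical helper, and the random tensors stand in for images.

```python
import torch
import torch.nn.functional as F


def pearson_r(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: per-sample Pearson r on flattened, mean-centered tensors
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    x = x - x.mean(dim=1, keepdim=True)
    y = y - y.mean(dim=1, keepdim=True)
    return (x * y).sum(dim=1) / (x.norm(dim=1) * y.norm(dim=1))


torch.manual_seed(0)
target = torch.rand(2, 1, 16, 16)  # stand-in "images" with values in [0, 1]
pred = 0.1 * target + 0.5          # wrong contrast and brightness

print(1 - pearson_r(pred, target))  # about 0 for every sample
print(F.mse_loss(pred, target))     # clearly above 0
```

So minimizing 1 − r alone does not pin down the absolute intensities of the output, which is one reason it behaves differently from MSE on an identity-mapping task.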