Need help with a loss function for a surrogate model

I’m hoping to get advice on the loss function for a surrogate model. I work with micromechanical simulations of gas pores that form during additive manufacturing (AM), and because the simulations take a long time, I built a surrogate model that predicts the simulation output from a voxelized 3D pore volume. The target and the output are single-channel 3D stress fields in which each voxel’s intensity is a stress value, typically between about -100 and 300. The loss therefore compares the predicted field (the "mask" in my code) with the target stress field. MSE alone seemed to work great, giving an R² of 99.87%, but a closer look at the histograms showed that the predicted stress distribution was somewhat off, as shown below.
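A very high R² and a compressed stress range are not contradictory, which is easy to check with a toy example. The sketch below uses hypothetical values (not simulation data) and shrinks every deviation from the mean by a factor `k`. Because the mean is preserved, R² works out to 1 - (1 - k)², so `k = 0.9` still gives R² = 0.99 while the peak drops from 300 to 276 and the minimum rises from -100 to -84.

```python
import statistics


def r_squared(pred, target):
    """Coefficient of determination: R^2 = 1 - SS_res / SS_tot."""
    mean_t = statistics.fmean(target)
    ss_res = sum((p - t) ** 2 for p, t in zip(pred, target))
    ss_tot = sum((t - mean_t) ** 2 for t in target)
    return 1.0 - ss_res / ss_tot


# Hypothetical flattened "stress field" values, for illustration only.
target = [-100.0, -20.0, 0.0, 40.0, 80.0, 120.0, 300.0]

# Prediction that keeps the mean but shrinks each deviation from it by 10 %.
k = 0.9
mean_t = statistics.fmean(target)
pred = [mean_t + k * (t - mean_t) for t in target]

r2 = r_squared(pred, target)
print(f"R^2 = {r2:.4f}")                                  # R^2 = 0.9900
print(f"max: target {max(target):.1f}, pred {max(pred):.1f}")  # max: target 300.0, pred 276.0
print(f"min: target {min(target):.1f}, pred {min(pred):.1f}")  # min: target -100.0, pred -84.0
```

Since MSE penalises the squared shrinkage only weakly near k = 1, a model can sit at a very high R² while still squeezing the tails.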

While the network captures the general field morphology well, in most cases it predicts minima that are not low enough and maxima that are not high enough, essentially squeezing the distribution together. To match the distributions better, I added an Earth Mover’s Distance (Wasserstein distance) term: it builds histograms of both stress fields, converts them to CDFs, and applies an L1 or L2 loss between the CDFs. This improved the distribution match, but the network still undershoots the maximum stress, only slightly less than before.

The issue is that I use the predicted stress fields to calculate the stress concentration factor (SCF), which depends on the maximum stress and the average stress, so an accurate maximum stress is critical. I tried adding an extra term for the squared error between the target and predicted maximum stresses, but it badly distorted the distributions and barely changed the maximum stress. I definitely have more hyperparameter exploration to do, but does anyone have any ideas?
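One hedged idea for the peak problem: in the posted code, the commented-out max term uses `torch.max` over the whole flattened batch, so its gradient reaches only the single highest voxel in the batch, which could explain why it barely moved the peak. Matching a per-sample "soft" peak, such as the mean of the `k` highest voxel stresses, spreads the signal over more voxels and every sample. This is a sketch under assumptions, not a tested recipe: `topk_peak_loss`, `stress_concentration_factor` and `k` are hypothetical names, and the SCF formula below (per-sample max stress divided by mean stress) is assumed, since the post only says the SCF depends on both.

```python
import torch


def topk_peak_loss(pred: torch.Tensor, target: torch.Tensor, k: int = 32) -> torch.Tensor:
    """Squared error between the per-sample means of the k highest voxel stresses.

    pred, target: (B, ...) stress fields, flattened per sample.
    Each sample contributes, and gradient reaches k voxels per sample
    rather than one voxel for the whole batch.
    """
    pred = pred.flatten(start_dim=1)
    target = target.flatten(start_dim=1)
    k = min(k, pred.shape[1])
    pred_peak = pred.topk(k, dim=1).values.mean(dim=1)      # (B,)
    target_peak = target.topk(k, dim=1).values.mean(dim=1)  # (B,)
    return torch.mean((pred_peak - target_peak) ** 2)


def stress_concentration_factor(field: torch.Tensor) -> torch.Tensor:
    """ASSUMED definition: SCF = max stress / mean stress, computed per sample.

    Unstable if the mean stress of a sample is close to zero.
    """
    field = field.flatten(start_dim=1)
    return field.amax(dim=1) / field.mean(dim=1)


# Small demo on random fields shaped like (B, C, D, H, W).
torch.manual_seed(0)
pred = torch.randn(2, 1, 8, 8, 8, requires_grad=True)
target = torch.randn(2, 1, 8, 8, 8)
loss = topk_peak_loss(pred, target, k=16)
loss.backward()
print("voxels receiving gradient:", int((pred.grad != 0).sum()))
```

It would be added as a third weighted term next to MSE and EMD. A larger `k` gives a smoother, more stable signal but targets a broader "peak"; `k = 1` per sample recovers a per-sample hard maximum.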

Also, here is my loss function for reference in case someone spots something. I made sure to use Kornia so the histograms are differentiable, although I wasn’t sure if that was even needed.

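```python
import torch
import torch.nn as nn
from torch.nn import L1Loss, MSELoss
from kornia.enhance import histogram


class EarthMoversMSELoss(nn.Module):
    """Weighted sum of voxel-wise MSE and a histogram-based Earth Mover's Distance."""

    def __init__(self, bins=25, p=1, do_root=True, alpha=0.5, bandwidth=0.9):
        super().__init__()
        self.bins = bins
        self.p = p
        self.do_root = do_root
        self.alpha = alpha
        # Kornia's histogram is a Gaussian-kernel estimate and `bandwidth` is the
        # kernel width in stress units. If it is much smaller than the bin spacing
        # ((max - min) / (bins - 1), about 17 for a -100..300 range and 25 bins),
        # voxels lying between bin centres contribute almost nothing to the
        # histogram or to its gradient.
        self.bandwidth = bandwidth
        self.mae_loss = L1Loss()
        self.mse_loss = MSELoss()

    def forward(self, mask, target):
        # (B, C, D, H, W) -> (B, N) so each sample gets its own histogram
        mask = torch.flatten(mask, start_dim=1)
        target = torch.flatten(target, start_dim=1)

        # Shared bin centres spanning both fields (global over the batch).
        # .item() detaches them, which is fine: the bins need no gradient.
        min_val = torch.min(mask.min(), target.min())
        max_val = torch.max(mask.max(), target.max())
        bins = torch.linspace(min_val.item(), max_val.item(), self.bins, device=mask.device)
        bandwidth = torch.tensor(self.bandwidth, device=mask.device)

        # Normalised soft histograms of shape (B, bins); each row sums to ~1
        mask_hist = histogram(mask, bins, bandwidth=bandwidth)
        target_hist = histogram(target, bins, bandwidth=bandwidth)

        #mae_loss = self.mae_loss(mask, target)
        mse_loss = self.mse_loss(mask, target)
        emd_loss = self.emd_loss(mask_hist, target_hist, p=self.p, do_root=self.do_root)
        # A max over the whole batch (torch.max(mask)) sends gradient to one voxel;
        # a per-sample max at least involves every sample in the batch.
        #max_val_loss = self.mse_loss(mask.amax(dim=1), target.amax(dim=1))

        # MSE is in stress units squared while the CDF term lies in [0, 1], so the
        # two terms can differ by orders of magnitude unless the fields are normalised.
        return (1 - self.alpha) * mse_loss + self.alpha * emd_loss  # + 0.01 * max_val_loss

    def emd_loss(self, mask_hist, target_hist, p=1, do_root=True):
        # Earth Mover's Distance between 1D distributions via their CDFs.
        # Averaging over bins (instead of summing times the bin width) gives, for
        # p = 1, roughly W_1 / (max_val - min_val), so the scale varies per batch.
        mask_cdf = torch.cumsum(mask_hist, dim=-1)
        target_cdf = torch.cumsum(target_hist, dim=-1)

        if p == 1:  # L1 between CDFs
            emd = self.mae_loss(mask_cdf, target_cdf)

        elif p == 2:  # L2 between CDFs; sqrt of the batch mean
            emd = self.mse_loss(mask_cdf, target_cdf)
            if do_root:
                emd = torch.sqrt(emd)

        else:  # general p-norm, rooted per sample before averaging
            emd = torch.mean(torch.pow(torch.abs(mask_cdf - target_cdf), p), dim=-1)
            if do_root:
                emd = torch.pow(emd, 1 / p)

        return torch.mean(emd)  # mean over the batch (already scalar for p = 1, 2)
```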
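An alternative to the histogram-based EMD, named plainly as a swapped-in technique: when both fields have the same number of voxels, the 1D Wasserstein distance between their value distributions can be computed exactly by sorting both and comparing them element by element, W_p = (mean_i |x_(i) - y_(i)|^p)^(1/p). There are no bins or bandwidth to choose, and because the largest predicted value is always paired with the largest target value, the upper tail is penalised directly, which may help with the undershot maximum. `sorted_wasserstein_loss` is a hypothetical helper, a sketch rather than a replacement validated on this data.

```python
import torch


def sorted_wasserstein_loss(pred: torch.Tensor, target: torch.Tensor, p: int = 1) -> torch.Tensor:
    """Exact p-Wasserstein distance between per-sample voxel-value distributions.

    pred, target: (B, ...) tensors with the same number of voxels per sample.
    torch.sort is differentiable with respect to the sorted values.
    """
    pred_sorted = pred.flatten(start_dim=1).sort(dim=1).values
    target_sorted = target.flatten(start_dim=1).sort(dim=1).values
    per_sample = (pred_sorted - target_sorted).abs().pow(p).mean(dim=1)
    # clamp_min avoids an infinite gradient of x**(1/p) at exactly zero
    return per_sample.clamp_min(1e-12).pow(1.0 / p).mean()  # average over the batch


# Toy check: sorted values are [1, 2, 3] vs [1, 2, 4], so W_1 = 1/3.
x = torch.tensor([[3.0, 1.0, 2.0]])
y = torch.tensor([[1.0, 2.0, 4.0]])
print(sorted_wasserstein_loss(x, y).item())  # about 0.3333
```

It could stand in for `self.emd_loss(...)` in `forward`, or be used alongside it. Sorting N voxels per sample costs O(N log N) time and O(N) memory, whereas a Gaussian-kernel histogram evaluates every voxel against every bin.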