Problem with a custom loss function

Hello! I am writing a custom loss function and running into a strange issue, which may be related to the way gradients are computed, but I am not sure, so any help would be appreciated. In a nutshell, I am testing a model: I simulate artificial data and then check whether the network recovers the values used in the simulation. If the recovered values are close to the simulated ones, the code is working. I have a custom loss function where everything works fine, and it includes this:

import torch
import torch.nn as nn


class TrendModelLoss(nn.Module):

  def __init__(self, theta, scaling, nT, YL):

    super(TrendModelLoss, self).__init__()
    # Fixed quantities
    self.theta    = theta
    self.scaling  = scaling
    self.nT       = nT
    self.YL       = YL 
    # Estimated parameters (sigma is stored on the log scale)
    self.sigma    = nn.Parameter(torch.tensor([-1.0]))
    self.m0       = nn.Parameter(torch.tensor([0.0]))


  def forward(self, predictions, targets):

    # Rescale sigma
    sigma_exp = torch.exp(self.sigma)

    # Initialize m_t as None
    m_t = None

    # Allocate space for the vector m_t for all time steps
    m_vector = torch.zeros_like(predictions)

    # Some code that transforms the predictions into m_vector...

    # Negative log-likelihood for y_t (plus prior term on sigma)
    squared_diff = (m_vector - targets) ** 2
    loss = (torch.sum(squared_diff) / (2 * sigma_exp ** 2)
            + self.nT / 2 * torch.log(sigma_exp ** 2)
            - 0.5 * (sigma_exp ** 2))  # PRIORS

    return loss

This works very well, and even the parameter sigma is recovered close to the value I used to simulate the artificial data. However, suppose I change just one line in the loss function:

squared_diff = (((1.0 - 0.0) * m_vector + 0.0 * self.YL[:,0]) - targets) ** 2

This should be mathematically equivalent to the previous code, since I only multiply by 1.0 and 0.0. But if I train the model now, things go very badly: sigma explodes, I can no longer recover the true simulated values, and the loss values are extremely high.

Does anyone know what may be happening?

Thank you!

Hi Charly!

I suspect that the shapes may be different, leading to different broadcasting.

Try printing out the shapes of m_vector, self.YL, self.YL[:, 0], and targets.
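For example (the shapes below are made up, just to illustrate what broadcasting can do): if m_vector and targets have shape (nT, 1) while self.YL[:, 0] has shape (nT,), the addition broadcasts to an (nT, nT) matrix and the summed squared difference becomes much larger than intended:

import torch

nT = 5
m_vector = torch.randn(nT, 1)   # 2D column tensor, shape (nT, 1)
targets  = torch.randn(nT, 1)   # 2D column tensor, shape (nT, 1)
yl_col   = torch.randn(nT)      # 1D tensor, shape (nT,), like self.YL[:, 0]

# Intended: element-wise difference, shape (nT, 1)
good = (m_vector - targets) ** 2
print(good.shape)               # torch.Size([5, 1])

# With the 1D tensor mixed in, broadcasting produces an (nT, nT) matrix
bad = ((1.0 * m_vector + 0.0 * yl_col) - targets) ** 2
print(bad.shape)                # torch.Size([5, 5])

# The sum now runs over nT * nT entries instead of nT
print(good.sum(), bad.sum())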

Best.

K. Frank

That was it! I was mixing 1D and 2D tensors. Thank you very much!
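In case it helps anyone else, one way to keep the slice two-dimensional (a sketch, assuming m_vector and targets have shape (nT, 1)) is to index with a range, or unsqueeze, instead of indexing with a single integer:

# Integer indexing drops the dimension and returns a 1D tensor
yl_1d = self.YL[:, 0]      # shape (nT,)

# Slicing with a range (or unsqueezing) keeps the column dimension
yl_2d = self.YL[:, 0:1]    # shape (nT, 1)
# equivalently: self.YL[:, 0].unsqueeze(1)

squared_diff = (((1.0 - 0.0) * m_vector + 0.0 * yl_2d) - targets) ** 2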