# Problem with a custom loss function

Hello! I am writing a custom loss function and I am encountering a weird issue, which may be related to the way gradients are computed, but I am not sure, so any help would be appreciated. In a nutshell, I am testing a model: I simulate artificial data, train the network, and then check whether the recovered parameters are close to the ones I used for the simulation. If they are, the code is working well. I have a custom loss function that works fine, and it includes this:

```python
class TrendModelLoss(nn.Module):

    def __init__(self, theta, scaling, nT, YL):
        super(TrendModelLoss, self).__init__()
        # Fixed quantities
        self.theta   = theta
        self.scaling = scaling
        self.nT      = nT
        self.YL      = YL
        # Estimated parameters
        self.sigma = nn.Parameter(torch.tensor([-1.0]))
        self.m0    = nn.Parameter(torch.tensor([0.0]))

    def forward(self, predictions, targets):
        # Rescale sigma
        sigma_exp = torch.exp(self.sigma)

        # Initialize m_t as None
        m_t = None

        # Allocate space for the vector m_t for all time steps
        m_vector = torch.zeros_like(predictions)

        # Some code that transforms the predictions into m_vector...

        # Calculate the (negative log-)likelihood for y_t
        squared_diff = (m_vector - targets) ** 2
        loss = (torch.sum(squared_diff) / (2 * sigma_exp ** 2)
                + self.nT / 2 * torch.log(sigma_exp ** 2)
                - 0.5 * (sigma_exp ** 2))  # priors

        return loss
```

This one works very well, and even the parameter `sigma` is estimated close to the value I used to simulate the artificial data. However, suppose that I change just one line in the loss function:

```python
squared_diff = (((1.0 - 0.0) * m_vector + 0.0 * self.YL[:, 0]) - targets) ** 2
```

This should be totally equivalent to the previous code, given that I only multiply by 0.0 or 1.0. But if I train the model now, things go very badly: `sigma` explodes, the loss values are extremely high, and I can no longer recover the true simulated values.

Does anyone know what may be happening?

Thank you!

Hi Charly!

I suspect that the shapes may be different, leading to different broadcasting.

Try printing out the shapes of `m_vector`, `self.YL`, `self.YL[:, 0]`, and `targets`.
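For example, here is a minimal sketch (with hypothetical shapes) of the kind of silent broadcasting that can blow up a loss like this: if `m_vector` is a column of shape `[nT, 1]` while `YL[:, 0]` is 1-D with shape `[nT]`, adding them produces an `[nT, nT]` matrix rather than the `[nT]`-element result you intended:

```python
import torch

nT = 4
m_vector = torch.zeros(nT, 1)  # column vector, shape [nT, 1]
yl_col   = torch.zeros(nT)     # 1-D slice, shape [nT]
targets  = torch.zeros(nT, 1)  # shape [nT, 1]

# [nT, 1] + [nT] broadcasts to [nT, nT] -- no error, no warning
mixed = (1.0 - 0.0) * m_vector + 0.0 * yl_col
print(mixed.shape)  # torch.Size([4, 4])

# The squared difference is now [nT, nT], so torch.sum() adds up
# nT**2 terms instead of nT, silently inflating the loss.
squared_diff = (mixed - targets) ** 2
print(squared_diff.shape)  # torch.Size([4, 4])
```

Multiplying by `0.0` does not help: the broadcast changes the shape of the result even when the values are zero.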

Best.

K. Frank

That was it! I was mixing 1D and 2D tensors. Thank you very much!
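For anyone hitting the same issue, a minimal sketch of the fix (assuming `YL` is 2-D of shape `[nT, k]`): slice with `0:1` (or use `.unsqueeze(1)`) so the column dimension is preserved and everything stays `[nT, 1]`:

```python
import torch

nT = 4
m_vector = torch.zeros(nT, 1)
YL       = torch.zeros(nT, 3)  # hypothetical [nT, k] tensor
targets  = torch.zeros(nT, 1)

yl_col = YL[:, 0:1]  # shape [nT, 1], unlike YL[:, 0] which is [nT]

# All operands are [nT, 1], so no unintended broadcasting occurs
mixed = (1.0 - 0.0) * m_vector + 0.0 * yl_col
squared_diff = (mixed - targets) ** 2
print(squared_diff.shape)  # torch.Size([4, 1])
```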