I have two loss functions. One is a custom loss with both `forward()` and `backward()` defined, called `MyMSELoss`. The other is PyTorch's built-in `mse_loss`.

I want to take a weighted sum of the two losses to get `combinedloss`. When I multiply `torch_mse` by 2, both the loss and the gradients change accordingly, but for `my_mse`, the gradients don't change when it is multiplied by 2.

**So my question is: how do I put a weight on the custom loss `my_mse` when combining the two losses?**

Here is the demo version of my code. Thanks in advance.

```
import torch
from torch.autograd import Function
import torch.nn.functional as F

class MyMSELoss(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        grad_input = 2 * (y_pred - y) / y_pred.shape[0]
        return grad_input, None

mseloss = F.mse_loss

# Usage example
y_pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0])

torch_mse = mseloss(y_pred, y)
my_mse = MyMSELoss.apply(y_pred, y)

combinedloss = torch_mse * 2 + my_mse * 2
combinedloss.backward()

# Access the gradients
gradients = y_pred.grad
print('loss', combinedloss)
print('gradients', gradients)
```
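My current suspicion is that `backward()` ignores `grad_output`, the upstream gradient that autograd passes in (here 2.0, coming from the `* 2`), so the weight never reaches the gradients. A minimal sketch of a version that folds it in via the chain rule (the class name `WeightedMyMSELoss` is just for illustration, not anything official):

```
import torch
from torch.autograd import Function

class WeightedMyMSELoss(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Chain rule: scale the local gradient by the upstream gradient,
        # so any weight applied to the loss propagates into y_pred.grad.
        grad_input = grad_output * 2 * (y_pred - y) / y_pred.shape[0]
        return grad_input, None

y_pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0])

loss = WeightedMyMSELoss.apply(y_pred, y) * 2  # weighted custom loss
loss.backward()
print(y_pred.grad)  # now twice the unweighted MSE gradient
```

With this change, scaling the custom loss should scale its gradients exactly like `F.mse_loss` does, since PyTorch's built-in ops all respect `grad_output`.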