I have two loss functions. One is a custom `torch.autograd.Function` called `MyMSELoss`, with both `forward()` and `backward()` defined. The other comes directly from PyTorch (`F.mse_loss`).
I want to compute a weighted sum of the two losses to get `combinedloss`. When I multiply `torch_mse` by 2, both the loss and the gradients change accordingly, but when I multiply `my_mse` by 2, they don't change.
So my question is: how can I put a weight on the custom loss `my_mse` when combining the two losses?
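For reference, the chain rule shows where the weight has to enter. Writing $N$ for the number of elements, $w_1, w_2$ for the two weights, and $L_{\text{torch}}, L_{\text{my}}$ for the two losses (notation introduced here for illustration):

$$
L = w_1 L_{\text{torch}} + w_2 L_{\text{my}}, \qquad
\frac{\partial L}{\partial \hat y_i} = w_1 \frac{\partial L_{\text{torch}}}{\partial \hat y_i} + w_2 \frac{\partial L_{\text{my}}}{\partial \hat y_i}, \qquad
\frac{\partial L_{\text{my}}}{\partial \hat y_i} = \frac{2(\hat y_i - y_i)}{N}.
$$

During `combinedloss.backward()`, autograd calls `MyMSELoss.backward` with `grad_output` equal to $\partial L / \partial L_{\text{my}} = w_2$ (here 2). A custom `backward` only reflects the weight if it multiplies its local gradient by `grad_output`.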
Here is the demo version of my code. Thanks in advance.
```python
import torch
from torch.autograd import Function
import torch.nn.functional as F


class MyMSELoss(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        # Save the inputs for use in the backward pass
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Gradient of the mean squared error w.r.t. y_pred
        # (grad_output is not used here)
        grad_input = 2 * (y_pred - y) / y_pred.numel()
        # No gradient is returned for the target y
        return grad_input, None


mseloss = F.mse_loss

# Usage example
y_pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0])

torch_mse = mseloss(y_pred, y)
my_mse = MyMSELoss.apply(y_pred, y)
combinedloss = torch_mse * 2 + my_mse * 2
combinedloss.backward()

# Access the gradients
gradients = y_pred.grad
print('loss', combinedloss)
print('gradients', gradients)
```
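A minimal sketch of one way to make the weight take effect, assuming the target `y` never needs a gradient: `backward` multiplies its local gradient by `grad_output`, so any scalar applied to `my_mse` propagates automatically. The helper `combined_loss` and the names `w_torch`, `w_my`, and `local_grad` are hypothetical, introduced only for illustration.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class MyMSELoss(Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        # Save the inputs for use in the backward pass
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Local derivative of mean((y - y_pred)^2) w.r.t. y_pred
        local_grad = 2 * (y_pred - y) / y_pred.numel()
        # Chain rule: scale by the upstream gradient, which carries
        # any weight applied to this loss (e.g. the factor 2 below).
        # Assumption: y does not require grad, so None is returned for it.
        return grad_output * local_grad, None


def combined_loss(y_pred, y, w_torch=2.0, w_my=2.0):
    # Hypothetical helper: weighted sum of the built-in and custom MSE
    return w_torch * F.mse_loss(y_pred, y) + w_my * MyMSELoss.apply(y_pred, y)


y_pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0])

loss = combined_loss(y_pred, y)
loss.backward()
print("loss", loss.item())
print("gradients", y_pred.grad)
```

With this change, the custom loss behaves like `F.mse_loss` under scaling: weighting `my_mse` by 3 gives the same gradient as weighting `torch_mse` by 3. Note that `torch.autograd.gradcheck` would not catch the original problem on its own, since it calls `backward` on a scalar output with `grad_output` equal to 1, where the buggy and fixed versions agree.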