How to implement a custom loss in PyTorch

Hi everyone,

I’m currently working on implementing a custom loss function for my project. The idea is to add a new loss term to a set of existing ones. Specifically, I’m introducing a novel component calculated as weight * MSE between two maps obtained online during training.

However, I’ve encountered an issue where adjusting the weight of this new loss term doesn’t seem to have any impact on the final loss and accuracy of my model. Even when I set the weight to 1 or a significantly larger value, the results remain unchanged.

More precisely, with the seeds fixed, we do see a contribution from the added loss (weight * MSE) in the result. The problem appears to be centered on the weight itself, which seems to have no influence on the outcome.

The code follows this reasoning:

loss = a * loss1 + b * loss2 + c * loss3 + d * loss4
# where d * loss4 is the custom term added by us

loss1 = nn.CrossEntropyLoss()
loss2 = nn.KLDivLoss()
loss3 = nn.MSELoss()
loss4 = nn.MSELoss()

The types are the following:

loss: tensor([value], device='cuda:0', grad_fn=)

loss1: tensor(value, device='cuda:0', grad_fn=)
loss2: tensor(value, device='cuda:0', grad_fn=)
loss3: tensor(value, device='cuda:0', grad_fn=)
loss4: tensor(value4, device='cuda:0', grad_fn=)

The types of a, b, c, d are: <class 'float'>

and the loss4 implementation has this structure:

loss4 = criterion(torch.tensor(factor1, requires_grad=True).cuda(0, non_blocking=True), torch.tensor(factor2, requires_grad=True).cuda(0, non_blocking=True))

Do you have any advice on how to proceed?
Thanks in advance

Hi Giovanni!

The factory function torch.tensor() instantiates new tensors from factor1
and factor2 that are no longer connected to the computation graph. Setting
requires_grad = True for these new tensors doesn’t reconnect them to
the existing computation graph, but, rather, makes them new leaves of the
computation graph.

So when you call loss.backward() on the full loss, gradients do not
backpropagate through factor1 and factor2 via the loss4 term. So
the weight of loss4 will, indeed, have no effect on your model’s training.

Why are you creating new tensors for loss4? The fix might be as simple as:

loss4 = criterion(factor1, factor2)

Best.

K. Frank
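To see the point above in isolation, here is a minimal sketch. The two nn.Linear layers are hypothetical stand-ins for the networks that produce factor1 and factor2 (the thread does not show them), and everything runs on the CPU. It compares the re-wrapped call from the question with the direct call.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks (not shown in the thread).
net = nn.Linear(4, 4)       # trainable network
frozen = nn.Linear(4, 4)    # frozen network
for p in frozen.parameters():
    p.requires_grad_(False)

x = torch.randn(8, 4)
criterion = nn.MSELoss()

factor1 = net(x)
factor2 = frozen(x)

# Variant from the question: torch.tensor() copies the data into new leaf
# tensors, so the loss is no longer connected to net's parameters.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # torch.tensor(tensor) emits a UserWarning
    rewrapped = criterion(
        torch.tensor(factor1, requires_grad=True),
        torch.tensor(factor2, requires_grad=True),
    )
rewrapped.backward()
grad_after_rewrap = net.weight.grad  # still None: nothing reached net

# Direct variant: factor1 keeps its history, so gradients reach net.
direct = criterion(factor1, factor2)
direct.backward()
grad_after_direct = net.weight.grad  # now a tensor with net.weight's shape

print("grad after re-wrapped loss:", grad_after_rewrap)
print("grad norm after direct loss:", grad_after_direct.norm().item())
```

In the re-wrapped variant the gradients end up on the two freshly created leaf tensors, which no optimizer ever sees.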


Hello Giovanni,

It sounds to me like you’re trying to build a smaller network (or do curve fitting) in order to optimize your loss function(s). To do that, we need to settle on a method for training the smaller network.

First, we can build such a model with the nn.Parameter class.

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss1 = nn.CrossEntropyLoss()
        self.loss2 = nn.KLDivLoss()
        self.loss3 = nn.MSELoss()
        self.loss4 = nn.MSELoss()

        # Setting these to 0.25 so their total starts out as 1.0.
        # nn.Parameter expects a tensor, not a Python float.
        self.a = nn.Parameter(torch.tensor(0.25))
        self.b = nn.Parameter(torch.tensor(0.25))
        self.c = nn.Parameter(torch.tensor(0.25))
        self.d = nn.Parameter(torch.tensor(0.25))

    def forward(self, x):
        # Schematic: each loss really needs its own (input, target) pair in place of x.
        return self.loss1(x) * self.a + self.loss2(x) * self.b + self.loss3(x) * self.c + self.loss4(x) * self.d

Note: The above is by no means an endorsement of your setup. My initial thought, without having run this, is that you’ll want something that self-regulates or constrains the total of (self.a + self.b + self.c + self.d) to be equal to 1 at all times. This way you can keep your loss from exploding. Here’s an alternative function that might accomplish that:

    def forward(self, x):
        y = 1 / (self.a + self.b + self.c + self.d)
        return y * (self.loss1(x) * self.a + self.loss2(x) * self.b + self.loss3(x) * self.c + self.loss4(x) * self.d)
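For reference, a runnable sketch of the normalized combination might look like the following. It is a simplification under stated assumptions: the class name NormalizedWeightedLoss is made up, the four weights are stored in one parameter vector instead of a, b, c, d, and the module receives loss values that were already computed (each criterion needs its own input and target, which the schematic forward(x) above glosses over).

```python
import torch
import torch.nn as nn


class NormalizedWeightedLoss(nn.Module):
    """Hypothetical helper: combines precomputed loss terms with learnable weights."""

    def __init__(self, num_terms=4):
        super().__init__()
        # Equal weights that start out summing to 1.0 (0.25 each for four terms).
        self.weights = nn.Parameter(torch.full((num_terms,), 1.0 / num_terms))

    def forward(self, losses):
        stacked = torch.stack(list(losses))
        # Dividing by the sum keeps the effective weights summing to 1.
        normalized = self.weights / self.weights.sum()
        return (normalized * stacked).sum()


combiner = NormalizedWeightedLoss()

# Placeholder loss values; in training these would be loss1 ... loss4.
terms = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0), torch.tensor(4.0)]
total = combiner(terms)
print(total.item())  # 0.25 * (1 + 2 + 3 + 4) = 2.5
```

Because of the division by the sum, rescaling all four weights by the same factor leaves the combined loss unchanged; only their ratios matter. The division does not by itself keep individual weights positive.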

Second, you obviously will not have pre-designated targets for your custom loss model. So a reinforcement learning method is in order. And your main model’s training is an iterative process. By that, I mean that the benefit of any changes made to these values during training of the custom loss model might not make itself manifest until many iterations have been processed. Because of this, a DQN or PPO might be the best approach for training the custom loss model.

PPO is more stable. Here is a PyTorch tutorial for that: Reinforcement Learning (PPO) with TorchRL Tutorial — torchrl main documentation


Hi Frank! Thanks a lot for the answer, I’ll explain myself better.
Factor1 and factor2 are two images generated from two networks (one of which is frozen). My aim is to make the first and second network-generated images close to each other, so I’m using the MSE with a weight.

I’m very interested in what you were saying about the loss4 term. The two tensors (factor1 and factor2) are generated at each epoch, and added in the following way:

criterion = nn.MSELoss()
loss4 = criterion(torch.tensor(factor1, requires_grad=True), torch.tensor(factor2, requires_grad=True)).cuda()

loss = a * loss1 + b * loss2 + c * loss3 + d * loss4

From your experience, can you see if this is done correctly? Unfortunately, the “d” weight seems to have no influence at all on the results. Should I define the factor tensors differently?
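A minimal sketch of the setup described in this post, under assumptions: the two nn.Conv2d layers are hypothetical stand-ins for the image-generating networks, the batch is random, and everything runs on the CPU. The trainable output is passed to the criterion directly, the frozen output is computed under torch.no_grad(), and the check at the end confirms that the weight d now scales the gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two image-generating networks.
trainable = nn.Conv2d(3, 3, kernel_size=3, padding=1)
frozen = nn.Conv2d(3, 3, kernel_size=3, padding=1)
frozen.requires_grad_(False)  # the frozen network is never updated
frozen.eval()

criterion = nn.MSELoss()
images = torch.randn(2, 3, 8, 8)  # placeholder input batch


def loss4_grad_norm(d):
    """Backpropagate d * loss4 and return the gradient norm on the trainable weights."""
    trainable.zero_grad()
    factor1 = trainable(images)       # stays attached to the graph
    with torch.no_grad():
        factor2 = frozen(images)      # a constant target, no graph needed
    loss4 = criterion(factor1, factor2)  # no torch.tensor() re-wrapping
    (d * loss4).backward()
    return trainable.weight.grad.norm().item()


norm_d1 = loss4_grad_norm(1.0)
norm_d10 = loss4_grad_norm(10.0)
print(norm_d1, norm_d10)  # the second value is about 10 times the first
```

If factor1 and factor2 already live on the GPU, no extra .cuda() call is needed on the inputs or on the resulting loss.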

Hi Johnson, thanks a lot for the answer. To further clarify, our primary objective is to bring two images, generated during training, closer together by optimizing an additional loss term.

I appreciate your suggestion to regularize the function by dividing it by the sum; I intend to incorporate it into the implementation.

Regarding the TorchRL implementation, do you mean giving the a, b, c, d parameters to a PPO algorithm so that it fine-tunes them online during training?

The way I’d set up a problem like this is:

  1. Find the best a to d parameters. During this stage, keep your main model as small as possible (as few parameters as you can). You want to train a lot of throw-away models quickly. Use a PPO algorithm for training the Custom Loss model.

  2. Train the main model. Freeze the Custom Loss model parameters so they stay static, and scale up the model you intend to train to the size you want.
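Step 2 can be sketched as follows. This only illustrates the freezing part, under assumptions: the weight values are made-up placeholders for whatever step 1 produces, main_model is a hypothetical nn.Linear, and the four loss terms are arbitrary stand-ins. The PPO search from step 1 is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

main_model = nn.Linear(4, 2)  # hypothetical main model

# Placeholder values standing in for the a, b, c, d found in step 1.
loss_weights = nn.Parameter(torch.tensor([0.4, 0.3, 0.2, 0.1]))
loss_weights.requires_grad_(False)  # static during main-model training

# Only the main model's parameters are handed to the optimizer.
optimizer = torch.optim.SGD(main_model.parameters(), lr=0.1)

x = torch.randn(8, 4)
target = torch.randn(8, 2)

weights_before = loss_weights.clone()
model_before = main_model.weight.detach().clone()

pred = main_model(x)
# Four arbitrary stand-in terms in place of loss1 ... loss4.
terms = torch.stack([
    F.mse_loss(pred, target),
    F.l1_loss(pred, target),
    F.smooth_l1_loss(pred, target),
    pred.pow(2).mean(),
])
loss = torch.dot(loss_weights, terms)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(loss_weights, weights_before))      # weights stayed fixed
print(torch.equal(main_model.weight, model_before))   # main model was updated
```

Setting requires_grad to False and leaving the weights out of the optimizer are each enough on their own; doing both makes the intent explicit.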

Note of Caution: There is an implicit assumption I’ve made here, which I think is very reasonable: that the best combination of the various loss functions will be the same regardless of the scale of the model being trained, given the same data and problem.
