Loss weight decay


I need to linearly decay a loss weight in PyTorch. This is how I calculate my loss function:

loss_A = criterion_id(G(real_A), real_A)
loss_B = criterion_id(F(real_B), real_B)

loss_id = (loss_A + loss_B) / 2

loss_ = lambda_ * loss_id

For example, lambda_ starts at 10, and I want it to decay linearly to a small value such as 1e-5 after some number of epochs (let's say 50).
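Written out, the requested schedule looks like this. It assumes that the decay runs over the first N epochs and that lambda_ then stays at its final value, which is how the answer below interprets the question:

```latex
\lambda(e) =
\begin{cases}
\lambda_0 + (\lambda_f - \lambda_0)\,\dfrac{e}{N}, & 0 \le e < N \\
\lambda_f, & e \ge N
\end{cases}
\qquad \lambda_0 = 10,\ \lambda_f = 10^{-5},\ N = 50
```

For example, halfway through the decay: lambda(25) = 10 + (1e-5 - 10) / 2 = 5.000005.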

You can define a function that returns the desired scaling factor for a given epoch:

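def get_scale_factor(epoch, init_factor=10, fin_factor=1e-5, num_epochs=50):
    # Once the decay period is over, hold the final value.
    if epoch >= num_epochs:
        return fin_factor

    # Linear interpolation: init_factor at epoch 0, fin_factor at epoch num_epochs.
    slope = (fin_factor - init_factor) / num_epochs
    return epoch * slope + init_factor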
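A quick way to check the function is to tabulate it over a few epochs. This sketch repeats the function so that it runs on its own. The total of 60 training epochs is a hypothetical value, and loss_id from the question is only referenced in a comment, because computing it needs the models G and F:

```python
def get_scale_factor(epoch, init_factor=10, fin_factor=1e-5, num_epochs=50):
    # Hold the final value once the decay period is over.
    if epoch >= num_epochs:
        return fin_factor
    # Linear interpolation from init_factor (epoch 0) to fin_factor (epoch num_epochs).
    slope = (fin_factor - init_factor) / num_epochs
    return epoch * slope + init_factor


total_epochs = 60  # hypothetical training length, longer than the 50-epoch decay
schedule = [get_scale_factor(epoch) for epoch in range(total_epochs)]

for epoch in (0, 10, 25, 49, 50, 59):
    print(f"epoch {epoch:2d}: lambda_ = {schedule[epoch]:.6f}")

# In the actual training loop, lambda_ would be recomputed once per epoch:
#     lambda_ = get_scale_factor(epoch)
#     loss_ = lambda_ * loss_id
```

Because the factor depends only on the epoch index, it only needs to be computed once per epoch, not once per batch.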

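Note that, if loss_ is the only term you optimize, I think scaling it like this has essentially the same effect as changing the learning rate of your model. Multiplying the loss by a factor multiplies the gradients by the same factor, and with plain SGD the updates applied to the model parameters are scaled accordingly, so scaling the learning rate produces the same updates. This does not hold if loss_ is added to other loss terms (lambda_ then also changes their relative weighting), or if you use an adaptive optimizer such as Adam, which is largely insensitive to a constant scale on the gradients. For the equivalent case, you may want to look at how to adjust the learning rate: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate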
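The equivalence for plain SGD can be checked numerically. This is a minimal pure-Python sketch that assumes vanilla SGD (no momentum, no weight decay) and a hypothetical one-parameter loss L(w) = (w - 3)**2. It runs the same number of steps once with a scaled loss and once with a scaled learning rate:

```python
import math


def grad(w):
    # Gradient of the hypothetical loss L(w) = (w - 3)**2.
    return 2.0 * (w - 3.0)


def sgd(w, lr, loss_scale=1.0, steps=5):
    """Plain SGD; multiplying the loss by loss_scale multiplies its gradient too."""
    trajectory = [w]
    for _ in range(steps):
        w = w - lr * (loss_scale * grad(w))
        trajectory.append(w)
    return trajectory


lr, scale = 0.01, 10.0
scaled_loss = sgd(0.0, lr, loss_scale=scale)  # lambda_ * loss with the original lr
scaled_lr = sgd(0.0, lr * scale)              # unscaled loss with lr * lambda_

for a, b in zip(scaled_loss, scaled_lr):
    print(f"{a:.6f}  {b:.6f}")
print("same trajectory:", all(math.isclose(a, b) for a, b in zip(scaled_loss, scaled_lr)))
```

The two trajectories agree up to floating-point rounding. With an optimizer such as Adam, the per-parameter normalization would largely cancel the loss scale, so the same comparison would not show this equivalence.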

PyTorch's learning-rate schedulers cover linear, exponential, and many other schedules.
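If you go the learning-rate route, you could write a linear decay with torch.optim.lr_scheduler.LambdaLR, which multiplies the initial learning rate by a user-supplied function of the epoch. This is a minimal sketch. The model, base learning rate, and epoch counts are hypothetical placeholders, and the multiplier reuses the same end-to-start ratio as lambda_ going from 10 to 1e-5:

```python
import torch

model = torch.nn.Linear(4, 1)  # hypothetical stand-in for G / F
base_lr = 2e-4                 # hypothetical base learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

num_epochs = 50
final_ratio = 1e-5 / 10  # same ratio as decaying lambda_ from 10 to 1e-5


def lr_multiplier(epoch):
    # Linear from 1.0 at epoch 0 to final_ratio at num_epochs, then held constant.
    if epoch >= num_epochs:
        return final_ratio
    return 1.0 + (final_ratio - 1.0) * epoch / num_epochs


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)

lrs = []
for epoch in range(60):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    # ... one epoch of training would go here (forward, loss.backward(), ...) ...
    optimizer.step()  # optimizer.step() comes before scheduler.step()
    scheduler.step()  # advance the schedule once per epoch
```

In PyTorch 1.10 and later, torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=final_ratio, total_iters=num_epochs) should give the same linear ramp without a custom function.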