I need to linearly decay the weight of a loss term in PyTorch. This is how I calculate my loss function:
loss_A = criterion_id(G(real_A), real_A)
loss_B = criterion_id(F(real_B), real_B)
loss_id = (loss_A + loss_B) / 2
loss_ = lambda_ * loss_id
For example, lambda_ starts at 10, and I want to decay it linearly towards a small number like 1e-5 over some number of epochs, let's assume 50.
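Written out, the schedule described above looks like this. It assumes the decay runs from epoch 0 to epoch N and that lambda_ then stays at the final value:

```latex
\lambda(e) =
\begin{cases}
\lambda_0 + (\lambda_f - \lambda_0)\,\dfrac{e}{N}, & 0 \le e < N,\\
\lambda_f, & e \ge N,
\end{cases}
\qquad \lambda_0 = 10,\ \lambda_f = 10^{-5},\ N = 50.
```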
You can define a function that returns the scaling factor for the current epoch:
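def get_scale_factor(epoch, init_factor=10, fin_factor=1e-5, num_epochs=50):
    # Once num_epochs is reached, hold the factor at its final value
    if epoch >= num_epochs:
        return fin_factor
    # Otherwise interpolate linearly from init_factor (epoch 0) towards fin_factor
    slope = (fin_factor - init_factor) / num_epochs
    return epoch * slope + init_factor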
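A minimal sketch of how the factor could be used once per epoch. The loop skeleton and the commented-out training steps are placeholders, not code from the question. Only the lambda_ computation actually runs here:

```python
def get_scale_factor(epoch, init_factor=10, fin_factor=1e-5, num_epochs=50):
    """Linearly interpolate from init_factor to fin_factor over num_epochs, then hold."""
    if epoch >= num_epochs:
        return fin_factor
    slope = (fin_factor - init_factor) / num_epochs
    return epoch * slope + init_factor


# Hypothetical training loop skeleton: only the lambda_ schedule is computed.
schedule = []
for epoch in range(60):
    lambda_ = get_scale_factor(epoch)
    schedule.append(lambda_)
    # ... compute loss_id as in the question, then:
    # loss_ = lambda_ * loss_id
    # loss_.backward(); optimizer.step(); optimizer.zero_grad()

print(schedule[0], schedule[25], schedule[50])
```

Recomputing lambda_ at the start of each epoch keeps the schedule a pure function of the epoch number, so it is easy to resume from a checkpoint.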
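Note that I think scaling the weight like this will have the same effect as changing the learning rate of your model. Currently, you are multiplying the loss by a factor, which scales the gradients by that same factor, and with them the updates applied to the model parameters. You can get the same effect by scaling the learning rate instead. Strictly, this equivalence holds for plain SGD when loss_ is your entire objective. Adaptive optimizers such as Adam largely cancel out a constant scale on the gradients. And if loss_ is added to other loss terms, lambda_ changes its weight relative to those terms rather than the overall step size. So, you may want to look at how to adjust the learning rate: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate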
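For plain SGD, the claim can be checked on a toy one-parameter problem. The quadratic loss (w - 3)^2, the starting point and the learning rates are made up for illustration and do not come from the question. Multiplying the loss by 10 produces the same updates as multiplying the learning rate by 10:

```python
def sgd_step(w, lr, loss_scale):
    """One plain SGD step on the toy loss loss_scale * (w - 3)**2."""
    grad = loss_scale * 2.0 * (w - 3.0)  # derivative of loss_scale * (w - 3)^2
    return w - lr * grad


w_scaled_loss = 0.0  # loss multiplied by 10, base learning rate 0.01
w_scaled_lr = 0.0    # unscaled loss, learning rate multiplied by 10
for _ in range(20):
    w_scaled_loss = sgd_step(w_scaled_loss, lr=0.01, loss_scale=10.0)
    w_scaled_lr = sgd_step(w_scaled_lr, lr=0.1, loss_scale=1.0)

print(w_scaled_loss, w_scaled_lr)  # the two trajectories agree up to float rounding
```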
The torch.optim.lr_scheduler module offers options such as linear (LinearLR), exponential (ExponentialLR), and many more.
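A sketch of the learning-rate route, assuming PyTorch 1.10 or later (where LinearLR was added). The nn.Linear model and the SGD optimizer are placeholders standing in for your own. The schedule shrinks the learning rate by the same overall ratio as decaying lambda_ from 10 to 1e-5, a factor of 1e-6, over 50 epochs:

```python
import torch

model = torch.nn.Linear(4, 1)  # placeholder model, stands in for G / F
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Multiply lr by a factor going linearly from 1.0 to 1e-6 (= 1e-5 / 10) over 50 epochs,
# after which the learning rate stays at its final value.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=1e-6, total_iters=50
)

lrs = []
for epoch in range(60):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... training steps for this epoch: loss.backward(), optimizer.zero_grad(), etc.
    optimizer.step()  # no gradients here, so this is a no-op; keeps the recommended order
    scheduler.step()  # advance the schedule once per epoch

print(lrs[0], lrs[25], lrs[50])
```

If you prefer to reuse get_scale_factor, torch.optim.lr_scheduler.LambdaLR multiplies the initial learning rate by a function of the epoch, so something like lr_lambda=lambda e: get_scale_factor(e) / 10 would give the same shape.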