How to implement a loss function that has a time-based parameter?


I need to implement a loss function which has a time-based parameter, i.e. a decay component which is based on the number of epochs. In optimization algorithms this is already implemented.

I can implement an extension of a loss function, but I am not sure where and how I can pass it a parameter for the epoch.

Is it possible to implement this in PyTorch? Could you show a simple example?

Thank you!

You can simply write an nn.Module that takes the iteration as an input and computes the loss.

Thank you for the reply.

Yes, I defined an nn.Module, but how can I pass the epoch during training?

In training, once I have defined my loss function, I only call loss.item() or loss.backward(), right? Could you give a simple example of how I could pass the parameter?

Thank you for the help.

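import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self, loss):
        super().__init__()
        self.loss = loss  # base criterion, e.g. nn.MSELoss()
    def forward(self, x, gt, epoch):
        loss = self.loss(x, gt)
        # Do your decay stuff here, using epoch
        return loss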
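
To show where the epoch gets passed, here is a minimal sketch of a full training loop. The names DecayedLoss and gamma, and the decay rule gamma ** epoch, are assumptions made up for illustration; swap in whatever decay you actually need. The point is only that the loss module is called once per step, so the epoch can go in as an extra argument to that call.

```python
import torch
import torch.nn as nn


class DecayedLoss(nn.Module):
    """Wraps a base criterion and scales it by gamma ** epoch.

    The exponential rule and the name `gamma` are illustrative assumptions.
    """

    def __init__(self, base_loss, gamma=0.9):
        super().__init__()
        self.base_loss = base_loss
        self.gamma = gamma  # hypothetical decay rate

    def forward(self, pred, target, epoch):
        weight = self.gamma ** epoch  # plain Python float, depends on the epoch
        return weight * self.base_loss(pred, target)


torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = DecayedLoss(nn.MSELoss(), gamma=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Toy data standing in for a real DataLoader
x = torch.randn(16, 4)
y = torch.randn(16, 1)

for epoch in range(3):
    optimizer.zero_grad()
    loss = criterion(model(x), y, epoch)  # the epoch is passed on every call
    loss.backward()
    optimizer.step()
    print(epoch, loss.item())
```

loss.backward() and loss.item() work as usual, because the decay weight is just a scalar multiplied into the loss. If you would rather keep the call as criterion(pred, target), another option is to store the epoch as an attribute on the module and update it at the start of each epoch.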