Looking for a function that resets the model to an earlier state from before the validation error increased


So in the past I have been implementing the following strategy, inspired by torch.optim.lr_scheduler.ReduceLROnPlateau:

  • after every patience number of epochs, calculate the validation error of the current model
  • if the validation error of the current model is lower than the previously calculated validation error, save the current model
  • otherwise, reset the model to the previously saved model and step down the learning rate
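The three steps above could be packaged as a small reusable helper. This is a minimal sketch, not an existing PyTorch API: the class name RollbackOnPlateau and its factor/min_lr parameters are hypothetical. It assumes the model and optimizer expose state_dict()/load_state_dict(), and that the optimizer exposes param_groups, as torch.nn.Module and torch.optim.Optimizer do. It also assumes you call step() yourself once every patience epochs with the current validation error.

```python
import copy
from typing import Any, Optional


class RollbackOnPlateau:
    """Hypothetical helper (not part of PyTorch).

    Keeps a copy of the best model/optimizer state seen so far; when the
    validation error does not improve, restores that copy and lowers the LR.
    `model` and `optimizer` are duck-typed: anything with state_dict() /
    load_state_dict() works, and the optimizer also needs `param_groups`.
    """

    def __init__(self, model: Any, optimizer: Any,
                 factor: float = 0.1, min_lr: float = 0.0):
        if not 0.0 < factor < 1.0:
            raise ValueError("factor must be in (0, 1)")
        self.model = model
        self.optimizer = optimizer
        self.factor = factor
        self.min_lr = min_lr
        self.best_loss: Optional[float] = None
        self.best_model_state = None
        self.best_optim_state = None

    def step(self, val_loss: float) -> bool:
        """Call once every `patience` epochs. Returns True if it rolled back."""
        if self.best_loss is None or val_loss < self.best_loss:
            # Improvement: snapshot weights and optimizer state. deepcopy is
            # needed because state_dict() shares storage with live tensors.
            self.best_loss = val_loss
            self.best_model_state = copy.deepcopy(self.model.state_dict())
            self.best_optim_state = copy.deepcopy(self.optimizer.state_dict())
            return False

        # No improvement. Read the *current* LRs first: the saved optimizer
        # state still holds the LR from when it was saved, so restoring it
        # would otherwise undo earlier reductions.
        current_lrs = [group["lr"] for group in self.optimizer.param_groups]

        # Restore the last checkpoint that improved the validation error.
        # Pass a copy so later in-place optimizer updates cannot alias it.
        self.model.load_state_dict(self.best_model_state)
        self.optimizer.load_state_dict(copy.deepcopy(self.best_optim_state))

        # Step down the learning rate of every parameter group.
        for group, lr in zip(self.optimizer.param_groups, current_lrs):
            group["lr"] = max(lr * self.factor, self.min_lr)
        return True
```

In a training loop you would construct it once, e.g. `rollback = RollbackOnPlateau(model, optimizer, factor=0.5)`, and call `rollback.step(val_loss)` after each validation pass. Because it only relies on state_dict()/load_state_dict() and param_groups, the same class should carry over between applications without changes.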

The above strategy is similar to ReduceLROnPlateau in that it steps down the learning rate when the validation error has not decreased after a patience number of epochs. However, with my method I am not stuck with a model that has already overfitted, because training resumes from the last saved model that actually improved the validation error.
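For comparison, this is how the built-in scheduler is normally driven. It is constructed from the optimizer alone, so it can lower the learning rate but has no handle on the model's weights. The tiny nn.Linear model and the constant loss values are placeholders chosen only for illustration.

```python
import torch

# Placeholder model/optimizer so the scheduler has something to act on.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Note: only the optimizer is passed in; the model is never seen.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# A validation loss that never improves: after more than `patience`
# non-improving steps, the LR is multiplied by `factor`.
for val_loss in [1.0, 1.0, 1.0, 1.0]:
    scheduler.step(val_loss)

print(optimizer.param_groups[0]["lr"])  # → 0.05
```

The weights simply continue from wherever training left off, which is the behaviour the rollback strategy is meant to avoid.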

The biggest issue with my approach is that, the way I have written the code, it takes a lot of rework to apply it to a different application (I have yet to write it in a generalisable way). I would much rather not do this each time.

Is there any chance that PyTorch already has a function that does something similar? If so, that would be awesome, and I would appreciate it if someone could point me to it. If not, is there a reason why? It shouldn't be too hard to implement. Is there some inherent flaw in the strategy that I might be overlooking?