I am currently looking deeper in the usage of learning rate schedulers. Hereby, a critical thought came by: if my learning rate goes down by using StepLR or ReduceLRonPlateau, the decreased learning rate leads to a more detailed search for the local minimum. Which is what I want. However, from the gradient perspective, this continued more detailed search does not happen at my current best minimum (based on validation loss), but rather at the current spot within the n-dimensional optimization space. Hence, the optimizer starts a more detailed search at a less optimal spot it could. So, I was thinking of writing aOptimizationCheckpointManager. The main idea behind this class is to save the model weights whenever the validation loss decreases and to collect the learning rate of the optimizer. If the learning rate changes, the OptimizationCheckpointManager overwrites the model weights of the current model with the best weights seen so far. This way, the training can continue from the point of the lowest validation loss with the new learning rate.
Here’s my first version of the class:
class OptimizationCheckpointManager(object):
def __init__(self, model, optimizer):
self.best_model_state = model.state_dict()
self.best_optimizer_state = optimizer.state_dict()
self.best_loss = float('inf')
self.current_lr = self._get_lr(optimizer)
@staticmethod
def _get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']
def update(self, val_loss, model, optimizer):
updated_lr = self._get_lr(optimizer)
if val_loss < self.best_loss:
self.best_loss = val_loss
self.best_model_state = model.state_dict()
self.best_optimizer_state = optimizer.state_dict()
if updated_lr < self.current_lr:
self.current_lr = updated_lr
## update learning rate in best optimizer state
for param_group in optimizer.param_groups:
param_group['lr'] = updated_lr
return True
else:
return False
def return_best_states(self):
return self.best_model_state, self.best_optimizer_state
The order within the training loops goes something like:
training → validation → lr_scheduler → OptimizationCheckpointManager
Well, currently I do not see any big difference in my results. So, my questions are:
1.) Is this ideal actually any good?
2.) And if yes, might I have a logical error within my class.