Checkpoint Loading not Working As Expected

Hi all! Recently, I started implementing checkpoint saving in my code, as it has become difficult to run everything in one go. I save everything as prescribed in the PyTorch tutorials; however, when I resume training, the network behaves as if the checkpoint hasn’t been loaded properly. You can see this behaviour in the figure below. I’ll also include the relevant extracts from my code, in case anybody can spot whether I am doing something silly. Thank you!
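
For reference, the save/resume pattern from the tutorials I’m following boils down to roughly this (the file name and the tiny module here are just placeholders):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                      # stand-in for the real network
optimizer = optim.SGD(model.parameters(), lr=0.01)
epoch = 5

# saving: bundle everything needed to resume into a single dictionary
torch.save({'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()},
           'checkpoint.pth')

# resuming: load the dictionary and restore each component
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']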



class Solver():    

    def __init__(self ... ):

        self.model = model
        self.optimizer = optimizer(model.parameters(), **optimizer_arguments)
        ....
        self.model_name = model_name
        ....
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
        self.start_epoch = 1
        self.start_iteration = 1
        ....

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs+1):
            ...... 
                if phase == 'train':
                    model.train()
                else:
                    model.eval()
               .....
                with torch.no_grad():
                    ....
                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(
                                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )
                ......
                if phase == 'train':
                    learning_rate_scheduler.step()
    ......

    def save_checkpoint(self, state, filename):

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):

        ....
            checkpoint_file_path = os.path.join(.....)
            self._checkpoint_reader(checkpoint_file_path)

    def _checkpoint_reader(self, checkpoint_file_path):
       ....
        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        self.model.load_state_dict = checkpoint['state_dict']
        self.optimizer.load_state_dict = checkpoint['optimizer']

        # Move the optimizer's internal state tensors (e.g. momentum buffers) to the training device
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']

A close-up view of the training performance, with the outliers removed. As you can see, even after the first peak the values don’t get close to where they were when training stopped, as if something hasn’t been loaded properly. However, when I print values and state_dicts while stepping through the code, everything seems to be there.

load_state_dict is a method, which expects the state_dict as its input.
Currently you are assigning the state_dict to the method name, which just overwrites the method instead of calling it:

self.model.load_state_dict = checkpoint['state_dict']
self.optimizer.load_state_dict = checkpoint['optimizer']

Replace these lines of code with:

self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])

and rerun the code. The same fix applies to the self.learning_rate_scheduler.load_state_dict line.
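
To see why the current version fails silently: assigning to model.load_state_dict just creates an instance attribute that shadows the method, so the randomly initialised parameters are never overwritten. A minimal, self-contained sketch (the module sizes and names are arbitrary):

import torch
import torch.nn as nn

trained = nn.Linear(2, 2)
checkpoint = {'state_dict': trained.state_dict()}

# Wrong: only shadows the method with a dict; the weights stay at their random init
restored = nn.Linear(2, 2)
restored.load_state_dict = checkpoint['state_dict']
print(torch.allclose(trained.weight, restored.weight))  # almost certainly False

# Right: calling the method copies the checkpointed parameters into the module
restored = nn.Linear(2, 2)
restored.load_state_dict(checkpoint['state_dict'])
print(torch.allclose(trained.weight, restored.weight))  # True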

Thank you @ptrblck! I’m rerunning now and will let you know the outcome!
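
In the meantime, a quick sanity check along these lines should confirm whether the restored weights actually match the checkpoint (the helper name and module sizes are just illustrative):

import torch
import torch.nn as nn

def state_dict_matches(model, checkpoint_state_dict):
    # Return True if every tensor in the model's state_dict equals the checkpoint's
    model_state = model.state_dict()
    return all(torch.equal(model_state[key].cpu(), tensor.cpu())
               for key, tensor in checkpoint_state_dict.items())

# Illustrative usage with a dummy checkpoint
model = nn.Linear(4, 2)
checkpoint = {'state_dict': nn.Linear(4, 2).state_dict()}
model.load_state_dict(checkpoint['state_dict'])
print(state_dict_matches(model, checkpoint['state_dict']))  # True once loading works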