Resume Training from a Checkpoint

I’m attempting to save and load the best model with torch, where I’ve defined my training function as follows:

def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):
    """Train ``model`` with SGD, optionally resuming from a saved checkpoint.

    The user is prompted once before training starts; answering ``y``/``yes``
    calls ``load_checkpoint`` and training continues from the epoch stored in
    the checkpoint instead of from epoch 0.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches that
            also exposes ``.dataset`` (used to normalise the running totals).
        test_loader: Loader handed to ``evaluate_model``.
        device: Device the model and the batches are moved to.
        learning_rate: Initial SGD learning rate.
        num_epochs: Total number of epochs, counted from 0 (a resumed run
            only executes the remaining ones).

    Returns:
        The trained model.
    """
    # The training configurations were not carefully selected.

    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65, 75], gamma=0.75, last_epoch=-1)
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    # Evaluation
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

    # Resume is decided ONCE, before the loop. Rebinding ``epoch`` inside a
    # ``for epoch in range(...)`` loop has no effect on the next iteration, so
    # the restored epoch has to become the start of the range instead.
    start_epoch = 0
    load_model = input('Load a model?')
    if load_model.strip().lower() in ('y', 'yes'):
        # load_checkpoint is documented to update the states of the objects it
        # is given, so only the stored epoch is taken from its return value.
        # NOTE(review): it reads its default path, which differs from the
        # per-epoch files write_checkpoint produces -- confirm which file
        # should be used to resume.
        _, _, start_epoch, _ = load_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler)
        # Make sure restored optimizer buffers (e.g. momentum) live on the
        # training device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    for epoch in range(start_epoch, num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:
            # as_tensor converts to float32 and moves to the device in one
            # step, and works for CPU and CUDA batches alike.
            inputs = torch.as_tensor(inputs, dtype=torch.float32, device=device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)

        # Set learning rate scheduler
        scheduler.step()

        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

        # Checkpoint every second epoch, after the epoch has fully finished
        # (``epoch // 2 == 0`` would only be true for epochs 0 and 1).
        if epoch % 2 == 0:
            write_checkpoint(model=model, epoch=epoch, scheduler=scheduler, optimizer=optimizer)

    return model

I’d like to be able to load a saved model and resume training from the epoch at which it was saved.

So far, based on some existing open discussions, I have methods to save the model, optimizer, and scheduler states along with the epoch:

def write_checkpoint(model, optimizer, epoch, scheduler, filename=None):
    """Save the training state needed to resume after ``epoch``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'state_dict'``,
    ``'optimizer'`` and ``'scheduler'``. ``'epoch'`` holds ``epoch + 1``,
    i.e. the next epoch to run when training is resumed.

    Args:
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Zero-based index of the epoch that has just finished.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        filename: Destination path. Defaults to
            ``/content/model_CP_epoch<epoch + 1>.pth``.
    """
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    if filename is None:
        filename = '/content/model_' + f'CP_epoch{epoch + 1}.pth'
    torch.save(state, filename)

def load_checkpoint(model, optimizer, scheduler, filename='/content/checkpoint.pth'):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Note: the input model, optimizer and scheduler must already be
    constructed. This routine only updates their states in place and returns
    the same objects.

    Args:
        model: Model that receives ``checkpoint['state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']``.
        scheduler: LR scheduler that receives ``checkpoint['scheduler']``.
        filename: Path of the checkpoint file.

    Returns:
        ``(model, optimizer, start_epoch, scheduler)``. ``start_epoch`` is
        the epoch stored in the checkpoint, or 0 when ``filename`` does not
        exist (in which case nothing is modified).
    """
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Load INTO the existing scheduler; rebinding the name to the raw
        # state dict would hand the caller a dict with no ``step()``.
        scheduler.load_state_dict(checkpoint['scheduler'])
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, scheduler

But I can’t work out the logic for making training restart at the saved epoch rather than at whatever value the for loop assigns to epoch.

Would it be possible for you to save it in the model? This way it would get saved when saving the state_dict