Loading a saved model to continue training

I tried to find a solution in other threads, but I couldn't find a problem like mine.

I am training a feed-forward NN and, once it is trained, I save it using:
torch.save(model.state_dict(), model_name)

Then I get some more data points and I want to retrain the model on the new set, so I load the model using:
model.load_state_dict(torch.load('file_with_model'))

When I start training the model again, the error increases a lot. To check whether the problem was the new points or the way I'm loading the model, I saved a trained model and loaded it again to retrain it on the same set of points. Even then, the error in the very first epoch is much higher than the error of the trained model.

Is this normal? Should I do anything else when loading a model to retrain it?
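One way to tell the two possible causes apart (the new points vs. the loading) is to evaluate the reloaded model on the same data before taking any optimizer step. If that loss matches the loss of the model that was saved, the weights were restored correctly, and the jump must come from resuming training (optimizer state, learning rate, train/eval mode). A minimal sketch, assuming a small made-up network and random data in place of the real ones, and an in-memory buffer instead of a file:

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    # Hypothetical feed-forward network standing in for the one in the question.
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))


x = torch.randn(32, 4)
y = torch.randn(32, 1)
criterion = nn.MSELoss()

model = make_model()
model.eval()
with torch.no_grad():
    loss_before = criterion(model(x), y).item()

# Save only the weights, as in the question (to a buffer rather than a file).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Load into a freshly constructed model and evaluate before any training step.
reloaded = make_model()
reloaded.load_state_dict(torch.load(buffer))
reloaded.eval()
with torch.no_grad():
    loss_after = criterion(reloaded(x), y).item()

print(loss_before, loss_after)
```

If the two numbers agree but the first training epoch still jumps, look at what is being reset when training restarts rather than at the weights.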

Thank you very much!


If you trained your model using Adam, you need to save the optimizer state dict as well and reload that. Also, if you used any learning rate decay, you need to reload the state of the scheduler because it gets reset if you don’t, and you may end up with a higher learning rate that will make the solution state oscillate. Finally, if you have any dropout or batch norm in your model architecture, and you saved your model after a test loop (in which case model.eval() was called), make sure to call model.train() before the training loop.
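These points can be combined into one checkpoint that holds the model, optimizer, and scheduler state, followed by a model.train() call after loading. A minimal sketch, assuming Adam with a StepLR schedule; make_model, save_checkpoint, resume, and the dictionary keys are made up for illustration:

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small network with dropout, so train/eval mode matters.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.1), nn.Linear(8, 1))


def save_checkpoint(f, epoch, model, optimizer, scheduler):
    # Store everything that affects the next update, not just the weights.
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # e.g. Adam's running averages
        "scheduler": scheduler.state_dict(),  # position in the LR schedule
    }, f)


def resume(f, model, optimizer, scheduler):
    checkpoint = torch.load(f)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    model.train()  # undo a model.eval() left over from a test loop
    return checkpoint["epoch"] + 1  # first epoch that still has to run


torch.manual_seed(0)
model = make_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
for epoch in range(3):
    optimizer.zero_grad()
    model(torch.randn(16, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()
model.eval()  # e.g. the checkpoint is written right after a validation pass

buffer = io.BytesIO()
save_checkpoint(buffer, epoch, model, optimizer, scheduler)
buffer.seek(0)

# Fresh objects, as in a new process, then restore their state.
model2 = make_model()
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=2, gamma=0.5)
start_epoch = resume(buffer, model2, optimizer2, scheduler2)
```

The optimizer and scheduler must be constructed the same way as in the original run before their state is loaded; load_state_dict only fills in state, it does not rebuild the objects.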


What @kevinzakka said.

After saving using something like

state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict(), 'losslogger': losslogger, }
torch.save(state, filename)

(losslogger is just something I use to keep track of the loss history; you can replace it with a tensorboard session or remove it)

…you can then reload the model weights, the state of your optimizer, and other things by calling something like

import os

import torch


def load_checkpoint(model, optimizer, losslogger, filename='checkpoint.pth.tar'):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger
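In this code, losslogger is only a record of past loss values, which is a different thing from the loss function (criterion) that computes each loss; the criterion is simply rebuilt in code and does not need to be checkpointed. A minimal sketch, assuming the logger is a plain Python list of per-epoch floats, which torch.save can store alongside the other entries:

```python
import io

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # the loss function: rebuilt in code, not checkpointed
losslogger = []                    # hypothetical loss history: a plain list of floats

torch.manual_seed(0)
target = torch.randint(0, 3, (8,))
for epoch in range(3):
    logits = torch.randn(8, 3)      # stand-in for model(input)
    loss = criterion(logits, target)
    losslogger.append(loss.item())  # store a Python float, not the tensor

buffer = io.BytesIO()
torch.save({"epoch": 3, "losslogger": losslogger}, buffer)
buffer.seek(0)
restored = torch.load(buffer)["losslogger"]
print(restored == losslogger)  # True
```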

Wait, uh oh. What I said is no longer true. This worked for me with earlier versions of PyTorch, but now in PyTorch 0.4, this has stopped working.

It appears to work, but later when you’re training, you get an error from the optimizer’s (Adam in my case) optimizer.step() method:

    exp_avg.mul_(beta1).add_(1 - beta1, grad)
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #4 'other'

Can anyone describe how to load properly with 0.4, so this doesn’t happen? Does one need to do optimizer.to(device) now, or something like that? My code is the same, and everything is .cuda()'d before the model is saved, so I don’t see why it’s expecting a non-cuda Tensor.

UPDATE: Found an answer in this issue.

So after you load from the checkpoint, when you move your model to cuda, you need to move the optimizer values as well, like so:

model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger)
model = model.to(device)
# now individually transfer the optimizer parts...
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

This works. Is there a more elegant solution, @apaszke?
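An alternative that avoids the manual loop is to move the model to the device before constructing the optimizer, and only then load the optimizer state. In the PyTorch versions I'm aware of, Optimizer.load_state_dict casts the loaded state tensors to the device of the parameters the optimizer holds, so Adam's buffers end up next to the weights. A minimal sketch under that assumption, with a toy nn.Linear model and a hypothetical helper build; map_location additionally lets a checkpoint saved on GPU be opened on a CPU-only machine:

```python
import io

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build(device):
    model = nn.Linear(4, 2).to(device)                          # move the model first...
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # ...then create the optimizer
    return model, optimizer


# Produce a checkpoint after one training step.
model, optimizer = build(device)
model(torch.randn(8, 4, device=device)).pow(2).mean().backward()
optimizer.step()
buffer = io.BytesIO()
torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Resume: the optimizer already holds parameters on `device`,
# so load_state_dict places the loaded Adam buffers there as well.
model2, optimizer2 = build(device)
checkpoint = torch.load(buffer, map_location=device)
model2.load_state_dict(checkpoint["state_dict"])
optimizer2.load_state_dict(checkpoint["optimizer"])
exp_avg = optimizer2.state[model2.weight]["exp_avg"]
print(exp_avg.device)
```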


OK, I see.
I think my problem is that I switch optimizers (from Rprop to LBFGS) once a certain training error is reached, so when I retrain, I start again with Rprop. I tried starting the retraining with LBFGS, and the error seems to behave well.

Thank you very much!
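When training switches optimizers partway through (here Rprop, then LBFGS), the checkpoint also needs to record which optimizer was active; otherwise a resumed run silently starts over with the first one. A minimal sketch, assuming a hypothetical optimizer_name key and a hand-written OPTIMIZERS table with placeholder hyperparameters:

```python
import io

import torch
import torch.nn as nn

# Hypothetical table: stored optimizer name -> how to rebuild it (placeholder hyperparameters).
OPTIMIZERS = {
    "Rprop": lambda params: torch.optim.Rprop(params, lr=0.01),
    "LBFGS": lambda params: torch.optim.LBFGS(params, lr=1.0),
}


def save(f, model, optimizer):
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer_name": type(optimizer).__name__,  # which optimizer was active
        "optimizer": optimizer.state_dict(),
    }, f)


def load(f, model):
    checkpoint = torch.load(f)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer = OPTIMIZERS[checkpoint["optimizer_name"]](model.parameters())
    optimizer.load_state_dict(checkpoint["optimizer"])
    return optimizer


torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(20, 3), torch.randn(20, 1)


def closure():
    # LBFGS re-evaluates the loss several times per step, so it needs a closure.
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    loss.backward()
    return loss


optimizer = OPTIMIZERS["Rprop"](model.parameters())
optimizer.step(closure)                              # Rprop also accepts a closure
optimizer = OPTIMIZERS["LBFGS"](model.parameters())  # e.g. once a target error is reached
optimizer.step(closure)

buffer = io.BytesIO()
save(buffer, model, optimizer)
buffer.seek(0)
resumed = load(buffer, nn.Linear(3, 1))
print(type(resumed).__name__)  # LBFGS
```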

I see you also restore the optimizer state. Is that because you use Adam, or does it apply to other optimizers as well?

It depends. Vanilla SGD doesn't keep any state from previous steps, so there'd be no point in restoring optimizer info for it. I'd say that if restarting the optimizer isn't having an adverse effect (e.g. you don't notice a giant jump in the loss when you restart), you can get by without worrying about it.
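To check whether a given optimizer carries history worth restoring, you can inspect optimizer.state_dict()["state"] after a step. A minimal sketch on a toy model (the helper tensor_state_keys is made up): plain SGD keeps no per-parameter tensors, SGD with momentum keeps a momentum buffer, and Adam keeps running averages of the gradient and of its square. So SGD with momentum is not "vanilla" in this sense, and its state is worth saving too.

```python
import torch
import torch.nn as nn


def tensor_state_keys(optimizer_cls, **kwargs):
    """Take one step on a toy model and list the per-parameter state entries that are tensors."""
    torch.manual_seed(0)
    model = nn.Linear(3, 1)
    optimizer = optimizer_cls(model.parameters(), **kwargs)
    model(torch.randn(5, 3)).sum().backward()
    optimizer.step()
    per_param = optimizer.state_dict()["state"].get(0, {})  # state of the first parameter
    return sorted(k for k, v in per_param.items() if torch.is_tensor(v))


sgd_keys = tensor_state_keys(torch.optim.SGD, lr=0.1)  # vanilla SGD
momentum_keys = tensor_state_keys(torch.optim.SGD, lr=0.1, momentum=0.9)
adam_keys = tensor_state_keys(torch.optim.Adam, lr=1e-3)
print(sgd_keys, momentum_keys, adam_keys)
```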

Hey Scott,

thanks for this! Just saved me a lot of pain!

Cheers,
Alex


I don't think there is an issue with using Scott's old code. I don't get any errors, and the loss seems to pick up from where it left off.

So the models in torchvision.models are trained with vanilla SGD?

It will work, unless you use a GPU.


What is losslogger? In my code I have:

criterion = nn.CrossEntropyLoss()

and

loss = criterion(output, target)

I am wondering whether losslogger corresponds to criterion in my case.

Thank you

Hello everyone!

I tried to resume training a model I had already trained myself for 7 epochs, using your method, @drscotthawley. But with the code below, my accuracy decreases at the 8th epoch. I don't think this is normal, and I don't know where my error is.

PS: I used exactly the same lr_scheduler parameters for the first 7 epochs and for the following ones.

Thanks !

import os

import torch

model_ft = get_instance_segmentation_model(num_classes)
# construct an optimizer
params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001,
                            momentum=0.9, weight_decay=0.0005) 

data_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=4, shuffle=True, num_workers=8,
    collate_fn=lambda x: tuple(zip(*x)))

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=2, shuffle=False, num_workers=8,
    collate_fn=lambda x: tuple(zip(*x)))


def load_checkpoint(model, optimizer, filename):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch

model_ft , optimizer, epoch = load_checkpoint(model_ft, optimizer, "path/my_model")

model_ft = model_ft.to(device)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=5,
                                               gamma=0.1) 

# now individually transfer the optimizer parts...
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
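One piece of state this code does not carry over is the StepLR scheduler: it is rebuilt from scratch after loading, so it counts epochs from zero again and its decay points no longer line up with the first run. Whether that alone explains the accuracy drop is not certain, but it is cheap to rule out. A minimal sketch with a toy model, assuming the checkpoint also stores lr_scheduler.state_dict() under a hypothetical lr_scheduler key; it compares the learning rate after 10 epochs for an uninterrupted run, a resumed run with the scheduler restored, and a resumed run with a fresh scheduler:

```python
import io

import torch
import torch.nn as nn


def make():
    # Toy stand-ins using the same optimizer and scheduler settings as above.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    return model, optimizer, scheduler


def run_epochs(model, optimizer, scheduler, n):
    for _ in range(n):
        optimizer.zero_grad()
        model(torch.randn(8, 4)).pow(2).mean().backward()
        optimizer.step()
        scheduler.step()  # StepLR is stepped once per epoch


def lr_of(optimizer):
    return optimizer.param_groups[0]["lr"]


torch.manual_seed(0)
model, optimizer, scheduler = make()
run_epochs(model, optimizer, scheduler, 10)
reference_lr = lr_of(optimizer)  # 10 uninterrupted epochs

model, optimizer, scheduler = make()
run_epochs(model, optimizer, scheduler, 7)
buffer = io.BytesIO()
torch.save({"epoch": 7,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": scheduler.state_dict()}, buffer)


def resume(restore_scheduler):
    buffer.seek(0)
    checkpoint = torch.load(buffer)
    model, optimizer, _ = make()
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # As in the code above, the scheduler is created after the optimizer state is loaded.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    if restore_scheduler:
        scheduler.load_state_dict(checkpoint["lr_scheduler"])
    run_epochs(model, optimizer, scheduler, 10 - checkpoint["epoch"])
    return lr_of(optimizer)


print(reference_lr, resume(restore_scheduler=True), resume(restore_scheduler=False))
```

Only the restored run decays at epoch 10 the way the uninterrupted run does.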

See if this works:

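import torch

# Start from the new model's own weights and overwrite only compatible entries.
state_dict = model.state_dict()

# Assumes the file holds a plain state_dict (not the {'epoch': ..., 'state_dict': ...} dict above).
checkpoint = torch.load(filename)
avoid = ['fc.weight', 'fc.bias']  # layers to re-initialize, e.g. a new classifier head
for key in checkpoint.keys():
    # skip excluded layers and keys the new model does not have
    if key in avoid or key not in state_dict.keys():
        continue
    # skip tensors whose shape changed
    if checkpoint[key].size() != state_dict[key].size():
        continue
    state_dict[key] = checkpoint[key]
model.load_state_dict(state_dict)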
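The same filtering can be written as a dictionary comprehension followed by load_state_dict(..., strict=False), which returns the names of the parameters that were not loaded, so you can confirm that only the intended layers were skipped. As far as I know, strict=False on its own is not enough, because a shape mismatch still raises an error; hence the size check. A minimal sketch on a toy Net (hypothetical, like the helper load_compatible) whose classifier head changes size:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    # Toy model: a feature layer plus a classifier head named `fc`.
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


def load_compatible(model, checkpoint, avoid=("fc.weight", "fc.bias")):
    """Load only entries whose name and shape match `model`; report what was left out."""
    own = model.state_dict()
    filtered = {k: v for k, v in checkpoint.items()
                if k not in avoid and k in own and v.size() == own[k].size()}
    # strict=False tolerates the missing entries and returns their names.
    return model.load_state_dict(filtered, strict=False)


torch.manual_seed(0)
pretrained = Net(num_classes=10)
model = Net(num_classes=3)  # e.g. a new head for a different number of classes
result = load_compatible(model, pretrained.state_dict())
print(result.missing_keys)  # the fc entries keep their fresh initialization
```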

Hi, I am saving a SummaryWriter() object in my checkpoint to keep the loss history.
It raises this error: TypeError: cannot serialize '_io.BufferedWriter' object.

Doesn't torch.save() allow saving Python objects? How can I resolve this error?

Thanks.
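torch.save pickles whatever it is given, so it can store ordinary Python objects, but not objects that hold an open file, and a SummaryWriter keeps its event file open (presumably the _io.BufferedWriter named in the error). A workaround is to checkpoint plain data about the log instead of the writer itself. A minimal sketch; the helper can_torch_save, the checkpoint keys, and their values are made up for illustration:

```python
import io
import os

import torch


def can_torch_save(obj):
    """Return True if torch.save can serialize `obj` (written to an in-memory buffer)."""
    try:
        torch.save(obj, io.BytesIO())
        return True
    except Exception:  # e.g. TypeError: cannot pickle '_io.BufferedWriter' object
        return False


# An open binary file is an _io.BufferedWriter, the type named in the error.
with open(os.devnull, "wb") as handle:
    print(can_torch_save({"losslogger": handle}))  # False

# Plain data describing the log is fine (hypothetical values).
checkpoint = {
    "epoch": 7,
    "losslogger": [0.92, 0.71, 0.64],  # loss history as Python floats
    "log_dir": "runs/experiment_1",    # where the SummaryWriter writes its event files
    "global_step": 700,                # step to continue logging from
}
print(can_torch_save(checkpoint))  # True
```

On resume, a new writer can then be created with SummaryWriter(log_dir=checkpoint["log_dir"]) from torch.utils.tensorboard, and later add_scalar(tag, value, global_step) calls can continue from the saved step; TensorBoard should pick up all event files in that directory.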

Thanks a lot, Scott! You saved me a lot of pain, too!