Saving and Loading Trained Model Parameters

I want to train the AlexNet model from scratch:

import torch.nn as nn
from torchvision import models

model = models.alexnet(pretrained=False)
num_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1]  # Remove last layer
features.extend([nn.Linear(num_features, 4)])  # Add our layer with 4 outputs
model.classifier = nn.Sequential(*features)  # Replace the model classifier

# Parameters are trainable by default; this loop only makes that explicit.
for param in model.features.parameters():
    param.requires_grad = True

I want to save the parameters/hyperparameters (weights, bias, model values/structure etc.).

This is the training function that I am using:

def train(model, dataloaders, criterion1, criterion2, optimizer, num_epochs):
    since = time.time()
    train_loss = []
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        for phase in ['train', 'validation']:
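            if phase == 'train':
                model.train()  # assumed: training mode (e.g. dropout active)
            else:
                model.eval()  # assumed: evaluation mode for validation
            running_loss = 0.0
            running_corrects = 0
            for images, target in dataloaders[phase]:
                images = images.to(device)  # assumed: same device as the targets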
                labels = target['label'].to(device)
                norm = target['norm_box'].to(device)
                with torch.set_grad_enabled(phase == 'train'):
                    out_bbox = model(images)
                    loss_bbox = criterion2(out_bbox, norm)
                    loss = loss_bbox

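                    if phase == 'train':
                        # assumed: standard backward pass and parameter update
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()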

                running_loss += loss.item() * images.size(0)
                running_corrects += IoU(norm, out_bbox, 0.3, batch_size)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            if phase == 'train':
                train_loss.append(epoch_loss)  # assumed: record training loss per epoch
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'validation':
                val_acc_history.append(epoch_acc)  # assumed: record validation accuracy per epoch

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
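    print('Best val Acc: {:.4f}'.format(best_acc))
    # Assumed from the garbled original: restore the best weights and save them.
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), 'weights/best_wts.pth')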
    return model, train_loss, val_acc_history

When I load the state_dict:

wts = model.load_state_dict(torch.load('weights/best_wts.pth'))

I get the message:

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

And when trying to look through the state_dict:


I get the error:

AttributeError: 'IncompatibleKeys' object has no attribute 'keys'

Any advice would be greatly appreciated.

This message just states that no incompatible keys were found, so you are good to go.
Since this message can be confusing, there is a discussion about suppressing it when no incompatible keys are found.
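
To make the return value concrete, here is a minimal sketch. It assumes a small stand-in model (`nn.Linear`, not the AlexNet from the question) and an in-memory buffer in place of the file on disk. `load_state_dict` returns a report with `missing_keys` and `unexpected_keys` fields rather than the state dict itself, so the parameter names have to be read from `model.state_dict()` or from the object returned by `torch.load`.

```python
import io

import torch
import torch.nn as nn

# Stand-in model for illustration only (the thread uses a modified AlexNet).
model = nn.Linear(3, 2)

# Save the parameters to an in-memory buffer instead of 'weights/best_wts.pth'.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# torch.load returns the saved state dict (a mapping of name -> tensor).
state_dict = torch.load(buffer)

# load_state_dict returns a report of key mismatches, not the state dict.
result = model.load_state_dict(state_dict)
missing = list(result.missing_keys)        # empty when every expected key was found
unexpected = list(result.unexpected_keys)  # empty when no extra keys were present

# Inspect the parameter names through the state dict itself.
param_names = list(model.state_dict().keys())
print(missing, unexpected, param_names)
```

In the question's code this would mean iterating over `model.state_dict()` (or the result of `torch.load`) instead of over `wts`.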

Thanks for the reply.

Could you please also confirm if this method does indeed save the parameters/hyperparameters (weights, bias, model values/structure etc.) of the model?
I hope that makes sense.

Your current code only saves the model.state_dict(), i.e. all parameters and buffers of your model.
The model structure is not saved, so you should always keep the model definition close to the checkpoint.
Also, if you would like to use this checkpoint for fine-tuning, I would recommend storing the optimizer.state_dict() as well, since some optimizers (e.g. Adam) keep internal state that would otherwise be reset to its initial values.
Have a look at the ImageNet example to see how the checkpoint is created.
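
A checkpoint along those lines might look like the sketch below. It is not the code from the ImageNet example: the stand-in model, the dictionary key names, the `epoch` and `best_acc` entries, and the temporary file path are all assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model and optimizer for illustration only.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# One dummy update so that Adam has internal state worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Bundle everything needed to resume training into one dictionary.
# The key names and the 'epoch'/'best_acc' entries are hypothetical.
checkpoint = {
    'epoch': 1,
    'best_acc': 0.0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save(checkpoint, path)

# Later: rebuild model and optimizer from their definitions, then restore.
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.Adam(restored_model.parameters(), lr=1e-3)

loaded = torch.load(path)
restored_model.load_state_dict(loaded['model_state_dict'])
restored_optimizer.load_state_dict(loaded['optimizer_state_dict'])
start_epoch = loaded['epoch']
```

Hyperparameters such as the learning rate live in the optimizer's state dict, but anything else (batch size, number of epochs) has to be added to the dictionary by hand if it should travel with the checkpoint.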

Thank you very much for the explanation and the link.

Is there a reason why the standard way doesn't work for you?

see answer above:


Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case, the serialized data is bound to the specific classes
and the exact directory structure used, so it can break in various ways when
used in other projects, or after some serious refactors.
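
The first approach has its own requirement: because only named tensors are stored, the class used at load time must define matching parameter names and shapes. The sketch below illustrates this with two hypothetical classes, `TheModelClass` and `RenamedModel`, which are assumptions for the example and not part of the quoted documentation.

```python
import torch
import torch.nn as nn


class TheModelClass(nn.Module):
    """Hypothetical model with a configurable number of outputs."""

    def __init__(self, num_outputs):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, num_outputs)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


class RenamedModel(nn.Module):
    """Same layers, but the last one was renamed during a refactor."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.classifier = nn.Linear(16, 4)

    def forward(self, x):
        return self.classifier(torch.relu(self.backbone(x)))


saved_state = TheModelClass(num_outputs=4).state_dict()

# Matching definition: every key and shape lines up, so loading succeeds.
same = TheModelClass(num_outputs=4)
same.load_state_dict(saved_state)

# Renamed layer: strict loading (the default) raises a RuntimeError.
try:
    RenamedModel().load_state_dict(saved_state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest.
report = RenamedModel().load_state_dict(saved_state, strict=False)
missing = sorted(report.missing_keys)
unexpected = sorted(report.unexpected_keys)
```

Here `missing` lists the keys the new class expects but the file lacks, and `unexpected` lists the keys in the file that the new class does not define, which is the non-empty counterpart of the `IncompatibleKeys(missing_keys=[], unexpected_keys=[])` message from the question.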