Missing keys & unexpected keys in state_dict when loading self trained model

Thanks for your suggestions! @ptrblck

Can we load the state_dict of a previously trained NN into a new NN that has some changes in its structure? For example, I needed only part of the state_dict of the previously trained NN, which I managed to load successfully. Now I have added one more layer to my new NN, and I get an error saying that the loaded state_dict does not contain this output layer. How should I proceed?


Could you try to load the state_dict using model.load_state_dict(torch.load(PATH), strict=False)?
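For anyone trying this, it is worth keeping the object that load_state_dict returns and checking it, since strict=False silently skips every key that does not match. A minimal sketch, assuming two hypothetical classes OldNet and NewNet where the new one only adds an output layer (old_state stands in for torch.load(PATH)):

```python
import torch.nn as nn

class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.output = nn.Linear(8, 2)  # layer that is absent from the old checkpoint

old_state = OldNet().state_dict()  # stand-in for torch.load(PATH)
new_model = NewNet()

# strict=False skips mismatched keys and reports them in the returned object
result = new_model.load_state_dict(old_state, strict=False)

# Only the new layer should be reported as missing; anything else is a real mismatch
expected_missing = {"output.weight", "output.bias"}
assert set(result.missing_keys) == expected_missing, result.missing_keys
assert not result.unexpected_keys, result.unexpected_keys
```

If either assertion fails, the printed key names usually point at the actual naming mismatch.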


Thank you so much @ptrblck, it worked perfectly!

In my case it no longer threw any errors, but it didn’t load the state_dict correctly either.
In the end I had to do this to remove the “.module” part.

Can you explain more clearly how to wrap the current model in nn.DataParallel? Could you give an example?
I get the error Missing key(s) in state_dict:, and when I save the model I just use torch.save().
I am new to PyTorch, thanks so much!

You could just wrap the model in nn.DataParallel and push it to the device:

model = Model(input_size, output_size)
model = nn.DataParallel(model)
model.to(device)

I would not recommend saving the model directly; save its state_dict instead, as explained here.
Also, after you’ve wrapped the model in nn.DataParallel, the original model will be accessible via model.module, so you might want to store the state_dict via torch.save(model.module.state_dict(), 'file_name.pt').
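A rough sketch of that round trip, with nn.Linear standing in for Model(input_size, output_size) and an in-memory buffer standing in for 'file_name.pt' (both are placeholders for illustration):

```python
import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                  # stand-in for Model(input_size, output_size)
parallel_model = nn.DataParallel(model)

# Keys from the wrapper carry a "module." prefix; keys from .module do not
print(list(parallel_model.state_dict().keys()))
print(list(parallel_model.module.state_dict().keys()))

buffer = io.BytesIO()                    # in-memory stand-in for 'file_name.pt'
torch.save(parallel_model.module.state_dict(), buffer)
buffer.seek(0)

# A plain (unwrapped) model can load the checkpoint without any key changes
plain_model = nn.Linear(4, 2)
plain_model.load_state_dict(torch.load(buffer, map_location="cpu"))
```

Saving through model.module means the checkpoint does not depend on whether nn.DataParallel is used at load time.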


I use torch.save(model.state_dict(), 'file_name.pt') to save the model. If I don’t use nn.DataParallel when saving the model, I also don’t need it when loading it, right?
If so: I call torch.save(model.state_dict(), 'mymodel.pt') in one .py file during training, and when I try to load it in a new file with model.load_state_dict(torch.load('mymodel.pt')), it gives me the error
RuntimeError: Error(s) in loading state_dict for model: Missing key(s) in state_dict: "l1.W", "l1.b", "l2.W", "l2.b".
My model is a self-defined model with two self-constructed layers, l1 and l2, each of which has parameters W and b.
Can you help me with this? Thanks!

Could you post the model definition, please?
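In the meantime, comparing the key names on both sides often narrows this down. A small sketch; diff_state_dict_keys and the two classes are hypothetical names for illustration, with Saved playing the role of whatever produced the checkpoint:

```python
import torch.nn as nn

def diff_state_dict_keys(model, state_dict):
    """Return (missing, unexpected) key lists, mirroring the error message."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)

class Saved(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(3, 3)   # attribute name used when the checkpoint was written

class Current(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(3, 3)       # attribute name used by the model being loaded

missing, unexpected = diff_state_dict_keys(Current(), Saved().state_dict())
print("missing:", missing)
print("unexpected:", unexpected)
```

Keys that are missing with nothing unexpected usually mean the checkpoint was written before the parameters existed or from a different object; a matching pair of missing and unexpected names usually means a renamed attribute or a wrapper prefix.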

Yes, that is correct.

I also am facing the same issue. I don’t use nn.DataParallel.
I am using a slightly modified version of this repo in a Kaggle notebook: https://github.com/aitorzip/PyTorch-CycleGAN
Here’s how I save:

torch.save({
    'epoch': epoch,
    'model_state_dict': netG_A2B.state_dict(),
    'optimizer_state_dict': optimizer_G.state_dict(),
    'loss_histories': loss_histories,
}, f'netG_A2B_{epoch:03d}.pth')

Here’s how I load:

checkpoint = torch.load(weights_path)['model_state_dict']

And here’s a representative snippet of the (longer) error message:

Missing key(s) in state_dict: "1.weight", "1.bias", "4.weight", "4.bias", "7.weight", "7.bias".
Unexpected key(s) in state_dict: "model.1.weight", "model.1.bias", "model.4.weight", "model.4.bias", "model.7.weight", "model.7.bias". 

I opted for this method to fix it:

checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']
# Strip only a leading 'model.' so the keys match the bare network
for key in list(checkpoint.keys()):
    if key.startswith('model.'):
        checkpoint[key[len('model.'):]] = checkpoint.pop(key)
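One plausible cause of a prefix like this is that the checkpoint was saved from a class that keeps its layers in an attribute named model, while the load targets the inner nn.Sequential directly. A minimal reproduction under that assumption (the Generator class here is a simplified stand-in, not the repo's actual definition):

```python
import torch.nn as nn

class Generator(nn.Module):
    """Hypothetical wrapper that keeps its layers in an attribute called `model`."""
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.ReLU(), nn.Linear(4, 4))

    def forward(self, x):
        return self.model(x)

wrapped = Generator()
print(list(wrapped.state_dict().keys()))        # keys carry the attribute name as a prefix
print(list(wrapped.model.state_dict().keys()))  # the bare nn.Sequential has no prefix

# Loading into the same class that was saved needs no key surgery
same_class = Generator()
same_class.load_state_dict(wrapped.state_dict())

# Loading into the bare nn.Sequential requires stripping the leading "model."
bare = nn.Sequential(nn.ReLU(), nn.Linear(4, 4))
prefix = "model."
stripped = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in wrapped.state_dict().items()
}
bare.load_state_dict(stripped)
```

If the same class can be instantiated at load time, that is usually the less fragile of the two options.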

I was facing the same issue when I was using PyTorch lightning.

@williamFalcon Please take note that this was the way I solved the problem.

You can replace the module keys in the state_dict as follows:

pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}

Ideally, if you use DataParallel, save the checkpoint file as follows for inference:
torch.save(model.module.state_dict(), 'model_ckpt.pt')

This might also be useful for running inference using CPU, at a later time.

Ref: https://stackoverflow.com/a/61854807/3177661
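As an alternative to the dict comprehension, recent PyTorch versions ship a helper in torch.nn.modules.utils that strips a prefix only where it leads the key. A short sketch, assuming your installed version includes it (nn.Linear is a placeholder model, and the checkpoint is built in memory rather than loaded from disk):

```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

model = nn.Linear(4, 2)
# Checkpoint as it would look if saved from the DataParallel wrapper itself
pretrained_dict = nn.DataParallel(model).state_dict()

# Strips "module." only where it is a leading prefix; modifies the dict in place
consume_prefix_in_state_dict_if_present(pretrained_dict, "module.")

fresh = nn.Linear(4, 2)
fresh.load_state_dict(pretrained_dict)
```

Unlike key.replace("module.", ""), this leaves keys alone when "module." appears somewhere in the middle of a name.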


Thanks, you saved my day

It worked for me, no more error.

works for me, thank you

This worked. thanks a lot!!

I got this error for a different reason: I was trying to load variables that were not yet defined in the model. In the saved model, they were defined in a non-init function:

class NN(nn.Module):
    def __init__(self):
        super().__init__()

    def save(self, model_save_path):
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        save_dict = {}
        torch.save(save_dict, model_save_path)

    def fit(self, X_train, Y_train, X_val, Y_val, **kwargs):
        train_data_mean = X_train.mean(axis=0)
        train_data_std = X_train.std(axis=0)
        self.register_buffer('train_data_mean', torch.tensor(train_data_mean.values))
        self.register_buffer('train_data_std', torch.tensor(train_data_std.values))

        self.model.fit(X_train, Y_train, eval_set=[(X_train, Y_train), (X_val, Y_val)], **kwargs)

    def forward(self, inputs):
        return inputs

The solution was to initialize placeholder buffers in the constructor, which are then replaced by the loaded values:

class NN(nn.Module):
    def __init__(self):
        super().__init__()  # required before register_buffer can be called
        self.register_buffer('train_data_mean', torch.zeros(52))
        self.register_buffer('train_data_std', torch.ones(52))

    def save(self, model_save_path):
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        save_dict = {}
        torch.save(save_dict, model_save_path)

    def fit(self, X_train, Y_train, X_val, Y_val, **kwargs):
        train_data_mean = X_train.mean(axis=0)
        train_data_std = X_train.std(axis=0)
        self.register_buffer('train_data_mean', torch.tensor(train_data_mean.values))
        self.register_buffer('train_data_std', torch.tensor(train_data_std.values))

        self.model.fit(X_train, Y_train, eval_set=[(X_train, Y_train), (X_val, Y_val)], **kwargs)

    def forward(self, inputs):
        return inputs
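A stripped-down version of the same idea, for anyone who wants to try it in isolation. The Scaler class and its num_features argument are made up for this sketch; the point is only that the buffers exist (with the right shape) before load_state_dict runs:

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    """Hypothetical module that stores normalization statistics as buffers."""
    def __init__(self, num_features):
        super().__init__()
        # Placeholders with the final shape so load_state_dict has somewhere to copy into
        self.register_buffer("train_data_mean", torch.zeros(num_features))
        self.register_buffer("train_data_std", torch.ones(num_features))

    def fit(self, x):
        # Overwrite the placeholders in place instead of registering new buffers
        self.train_data_mean.copy_(x.mean(dim=0))
        self.train_data_std.copy_(x.std(dim=0))

    def forward(self, x):
        return (x - self.train_data_mean) / self.train_data_std

trained = Scaler(3)
trained.fit(torch.randn(10, 3))

restored = Scaler(3)  # fresh instance, fit() never called
restored.load_state_dict(trained.state_dict())
```

The placeholder shape has to match the saved tensors, otherwise load_state_dict reports a size mismatch instead of a missing key.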

the original answer is really dangerous and misleading … using strict False is a TERRIBLE idea without understanding the basic reason about why you are even getting the error. As explained by other posters, the model loads incompletely and gives all possible WRONG answers. Please take this answer away …if there’s a MOD somewhere here

The usage of strict=False can indeed be dangerous, if you ignore the returned object and do not check for expected missing or unexpected keys.

The mentioned use case of adding a single new layer to the model and trying to load the “old” state_dict sounds like a valid use case for strict=False, but you could also take the better approach of manipulating the state_dict directly.
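A minimal sketch of the direct approach, assuming the new model only appends a layer (the two nn.Sequential models are placeholders, and old_state stands in for torch.load(PATH)):

```python
import torch.nn as nn

old_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
new_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))  # adds layer "2"

old_state = old_model.state_dict()   # stand-in for torch.load(PATH)
new_state = new_model.state_dict()

# Keep only entries whose name and shape both match the new model
compatible = {
    k: v for k, v in old_state.items()
    if k in new_state and v.shape == new_state[k].shape
}
skipped = sorted(set(new_state) - set(compatible))
print("kept:", sorted(compatible), "left at init:", skipped)

# Merge into the new model's own state_dict, then load with the default strict=True
new_state.update(compatible)
new_model.load_state_dict(new_state)
```

Because the final load is strict, any key that was renamed by accident shows up in the printed list instead of being silently ignored.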

So I got this error while training a GAN-based Federated Learning model. I trained my model normally without differential privacy and there was no issue, but as soon as I attached PrivacyEngine to the discriminator, this error shows up after one round of global training and the global model doesn’t get updated.

here is how I save my model:
torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join('.', 'saved_model'), epoch))
torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join('.', 'saved_model'), epoch))

here is how I update my model at the client level:
def update_model(self, new_weights_D, new_weights_G):
    """Update the client's models with the new global parameters."""
    # Update the discriminator model

    # Update the generator model

    print("Updated the client models with global parameters.")

here is how I try to load the model at the server level:
def update_global_models(self, new_weights_D, new_weights_G):

    # Now, load the adjusted state dicts into the models

I get this error after one round of training:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GradSampleModule:
Missing key(s) in state_dict: "_module.main.0.weight", "_module.main.2.weight", "_module.main.3.weight", "_module.main.3.bias", "_module.main.5.weight", "_module.main.6.weight", "_module.main.6.bias", "_module.main.8.weight", "_module.main.9.weight", "_module.main.9.bias", "_module.main.11.weight".
Unexpected key(s) in state_dict: "main.0.weight", "main.2.weight", "main.3.weight", "main.3.bias", "main.5.weight", "main.6.weight", "main.6.bias", "main.8.weight", "main.9.weight", "main.9.bias", "main.11.weight".

can someone tell me what I’m doing wrong?
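The error names GradSampleModule and expects keys prefixed with "_module.", which suggests the discriminator is now wrapped and holds the original network in an attribute called _module, while the global weights still use the unwrapped names. The sketch below has not been run against Opacus; WrapperStandIn and make_net are hypothetical stand-ins that only reproduce the key layout:

```python
import torch.nn as nn

class WrapperStandIn(nn.Module):
    """Stand-in for a wrapper that stores the wrapped network as `_module`."""
    def __init__(self, module):
        super().__init__()
        self._module = module

def make_net():
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

global_weights = make_net().state_dict()     # keys like "0.weight"
wrapped_netD = WrapperStandIn(make_net())    # expects keys like "_module.0.weight"

# Option 1: load into the inner network, whose keys have no prefix
wrapped_netD._module.load_state_dict(global_weights)

# Option 2: add the prefix and load into the wrapper itself
prefixed = {f"_module.{k}": v for k, v in global_weights.items()}
wrapped_netD.load_state_dict(prefixed)

# Going the other way (client to server), strip the prefix before aggregating
prefix = "_module."
unwrapped = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in wrapped_netD.state_dict().items()
}
```

Whichever direction is used, the key names should be made consistent in one place (either always prefixed or always bare) so that the client and server updates agree.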
