Missing keys & unexpected keys in state_dict when loading self trained model

The original answer is really dangerous and misleading. Using strict=False is a TERRIBLE idea if you don't understand the underlying reason you are getting the error in the first place. As other posters explained, the model loads incompletely and can silently produce wrong answers. Please take this answer down, if there's a mod around here.


The usage of strict=False can indeed be dangerous if you ignore the returned object and do not check it for missing or unexpected keys.

The mentioned use case of adding a single new layer to the model and trying to load the “old” state_dict sounds like a valid use case for strict=False, but you could also take the better approach of manipulating the state_dict directly.
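To illustrate both options, here is a minimal sketch (with made-up toy models) of adding a single new layer and loading the old state_dict: first with strict=False while actually inspecting the returned object, then by manipulating the state_dict directly so the load stays strict.

```python
import torch
import torch.nn as nn

# Toy setup: the "new" model has one extra layer the old checkpoint lacks.
old_model = nn.Sequential(nn.Linear(4, 8))
new_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

state_dict = old_model.state_dict()

# Option 1: strict=False, but CHECK the returned object instead of ignoring it.
result = new_model.load_state_dict(state_dict, strict=False)
# Only the newly added layer's parameters should be missing, and nothing
# in the checkpoint should be unexpected. Anything else means a real mismatch.
assert result.missing_keys == ['1.weight', '1.bias']
assert result.unexpected_keys == []

# Option 2: manipulate the state_dict directly so a strict load succeeds.
full_sd = new_model.state_dict()
full_sd.update(state_dict)          # overwrite the keys the checkpoint provides
new_model.load_state_dict(full_sd)  # strict=True (the default) now works
```

Option 2 is more explicit: the new layer keeps its fresh initialization, all checkpointed parameters are loaded, and any unintended key mismatch still raises an error.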

So I got this error while training a GAN-based federated learning model. Training without differential privacy worked fine, but as soon as I attached a PrivacyEngine to the discriminator, this error shows up after one round of global training and the global model is no longer updated.

here is how I save my model:
torch.save(self.netG.state_dict(), '%s/netG_epoch_%d.pth' % (os.path.join('.', 'saved_model'), epoch))
torch.save(self.netD.state_dict(), '%s/netD_epoch_%d.pth' % (os.path.join('.', 'saved_model'), epoch))

here is how I update my model at the client level:
def update_model(self, new_weights_D, new_weights_G):
    """
    Update the client's models with the new global parameters.
    """
    # Update the discriminator model
    self.netD.load_state_dict(new_weights_D)

    # Update the generator model
    self.netG.load_state_dict(new_weights_G)

    print("Updated the client models with global parameters.")

here is how I try to load the model at the server level:
def update_global_models(self, new_weights_D, new_weights_G):

    # Now, load the adjusted state dicts into the models
    self.global_model_D.load_state_dict(new_weights_D)
    self.global_model_G.load_state_dict(new_weights_G)

I get this error after one round of training:

load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GradSampleModule:
	Missing key(s) in state_dict: "_module.main.0.weight", "_module.main.2.weight", "_module.main.3.weight", "_module.main.3.bias", "_module.main.5.weight", "_module.main.6.weight", "_module.main.6.bias", "_module.main.8.weight", "_module.main.9.weight", "_module.main.9.bias", "_module.main.11.weight".
	Unexpected key(s) in state_dict: "main.0.weight", "main.2.weight", "main.3.weight", "main.3.bias", "main.5.weight", "main.6.weight", "main.6.bias", "main.8.weight", "main.9.weight", "main.9.bias", "main.11.weight".

can someone tell me what I'm doing wrong?

Based on the error you are seeing, it seems the PrivacyEngine has added the _module prefix to your parameter names. You could either wrap the model with the PrivacyEngine again before loading the saved state_dict (assuming it will add the _module prefix again), or you could save the "raw" model instead. I'm not familiar with the PrivacyEngine, so I don't know exactly how the parameters are manipulated. If its wrapping is similar to DDP's, you should be able to access the original parameters directly via the ._module attribute.
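If remapping the keys is easier in your setup than re-wrapping the model, here is a minimal sketch that adds or strips the prefix. It assumes the prefix is the only difference between the two key sets, which is what your error message suggests (every missing key is exactly an unexpected key plus "_module."). The helper names are made up for illustration.

```python
import torch

PREFIX = "_module."  # prefix the wrapped model's parameter names carry

def add_prefix(state_dict, prefix=PREFIX):
    """Remap plain keys so they match the wrapped model's naming."""
    return {prefix + k: v for k, v in state_dict.items()}

def strip_prefix(state_dict, prefix=PREFIX):
    """Remap wrapped keys back to the plain model's naming."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

# Tiny demonstration with a single dummy entry:
plain_sd = {"main.0.weight": torch.zeros(3)}
wrapped_sd = add_prefix(plain_sd)
print(list(wrapped_sd))                # ['_module.main.0.weight']
print(list(strip_prefix(wrapped_sd)))  # ['main.0.weight']
```

In your case the global model expects the prefixed keys, so `add_prefix(new_weights_D)` before calling load_state_dict should make the key sets line up; stripping the prefix before saving would work in the other direction.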

Oh, thanks! Let me try that and get back to you.