Meaningless results after loading a saved model with scheduler

Hi, I trained a GAN on a dataset and then saved it using torch.save(). However, when I load this saved model and use the generator to generate results, they are completely meaningless.

While loading I am using

generator.load_state_dict(path['basic_celeba_gen'], strict=False)

I had to set this flag to False because during training I am using a scheduler that updates the learning rate at certain epochs, and if I don't pass strict=False, I get an error. However, I think this is what is causing the generation to fail after loading the model.

How do I fix this?

Using strict=False makes load_state_dict ignore missing or unexpected keys in the state_dict, so any parameters it cannot match stay randomly initialized. A learning rate scheduler only changes the learning rate and does not rename any parameters, so it should not trigger this error; I also suspect that using this argument is what's causing the issue.
Could you post the error message and explain why you think the lr scheduler is causing it?
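
In the meantime, note that load_state_dict also returns the mismatched keys, so you can inspect exactly what strict=False is silently skipping. A minimal sketch reusing your variable names (the checkpoint path is a placeholder):

import torch

checkpoint = torch.load('checkpoint.pth')  # placeholder for your file
# with strict=False, load_state_dict returns the mismatches instead of raising
result = generator.load_state_dict(checkpoint['basic_celeba_gen'], strict=False)
print(result.missing_keys)     # parameters the model expected but didn't get
print(result.unexpected_keys)  # checkpoint entries the model couldn't place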

Hi, thanks for the response. This is my error message:

RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict:
"main.0.weight", "main.1.weight", "main.1.bias", "main.1.running_mean", "main.1.running_var", 
"main.3.weight", "main.4.weight", "main.4.bias", "main.4.running_mean", "main.4.running_var", 
"main.6.weight", "main.7.weight", "main.7.bias", "main.7.running_mean", "main.7.running_var", 
"main.9.weight", "main.10.weight", "main.10.bias", "main.10.running_mean", "main.10.running_var", 
"main.12.weight".                                               
Unexpected key(s) in state_dict: "module.main.0.weight", "module.main.1.weight", 
"module.main.1.bias", "module.main.1.running_mean", "module.main.1.running_var", 
"module.main.1.num_batches_tracked", "module.main.3.weight", "module.main.4.weight", 
"module.main.4.bias", "module.main.4.running_mean", "module.main.4.running_var", 
"module.main.4.num_batches_tracked", "module.main.6.weight", "module.main.7.weight", 
"module.main.7.bias", "module.main.7.running_mean", "module.main.7.running_var", 
"module.main.7.num_batches_tracked", "module.main.9.weight", "module.main.10.weight", 
"module.main.10.bias", "module.main.10.running_mean", "module.main.10.running_var", 
"module.main.10.num_batches_tracked", "module.main.12.weight". 

I think the scheduler is causing my outputs to be meaningless because I previously ran the same training without the scheduler: I didn't get this error when loading the saved model, and the outputs were what I expected. I found another thread on this forum where someone had the same error, and the suggested fix was strict=False. But as you said, I think this argument is causing issues.

Alternatively, I am not sure how to account for missing weights when I load my trained model.

Yeah, you should be a bit careful with this argument: it silences a (valid) error, which is why it's not the default.

Your issue seems to be caused by nn.DataParallel, which prepends a module. prefix to the name of every parameter and buffer.
You could thus either remove the module. prefix from the keys of the state_dict manually, wrap the model in nn.DataParallel before calling load_state_dict, or save the state_dict of the underlying model using torch.save(model.module.state_dict(), PATH).
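
For the first option, a minimal sketch of stripping the prefix before loading, assuming the checkpoint layout from your first post (the path is a placeholder):

from collections import OrderedDict

import torch

checkpoint = torch.load('checkpoint.pth')  # placeholder path
# drop the 'module.' prefix that nn.DataParallel added to every key
state_dict = OrderedDict(
    (k.replace('module.', '', 1), v)
    for k, v in checkpoint['basic_celeba_gen'].items()
)
generator.load_state_dict(state_dict)  # strict=True (the default) now matches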

Thanks, will keep this in mind!

I am not sure what you mean by .module here. Do you mean calling torch.save() as

torch.save(m.module.state_dict(), PATH)

where m is an instance of my model class?

Yes. Since the module. prefix is added by nn.DataParallel, you can store the internal model's state_dict by accessing m.module.state_dict(), where m is the nn.DataParallel object:

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
print(model.state_dict())
> OrderedDict([('weight', tensor([[-0.8219]])), ('bias', tensor([-0.6302]))])

# DataParallel expects the module to already be on the default CUDA device
model = nn.DataParallel(model.cuda())
print(model.state_dict())
> OrderedDict([('module.weight', tensor([[-0.8219]], device='cuda:0')), ('module.bias', tensor([-0.6302], device='cuda:0'))])

# .module returns the wrapped model, whose keys carry no prefix
print(model.module.state_dict())
> OrderedDict([('weight', tensor([[-0.8219]], device='cuda:0')), ('bias', tensor([-0.6302], device='cuda:0'))])
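
For your setup, saving during training would then look something like this (the filename is a placeholder), so that loading into a plain, unwrapped Generator works with the default strict=True:

# m is the nn.DataParallel-wrapped generator used during training
torch.save({'basic_celeba_gen': m.module.state_dict()}, 'checkpoint.pth')

# later, on an unwrapped Generator instance:
checkpoint = torch.load('checkpoint.pth')
generator.load_state_dict(checkpoint['basic_celeba_gen'])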

Thanks for clarifying! Following your advice, I got an error when trying what you suggested:

AttributeError: 'Generator' object has no attribute 'module'

And when I printed the state_dict using print(Generator.state_dict()), the entries looked like this:

('main.10.weight', tensor(<size>, device='cuda:0'))

So there was no module. prefix in my keys.

But, I think I fixed it!

I checked my model class definition and found this condition:

if (device.type == 'cuda') and ngpu > 1:
    netG = nn.DataParallel(netG, list(range(ngpu)))

During training the condition was true, so the checkpoint keys carry the module. prefix, but when loading it evaluated to false and the plain generator's keys didn't match. So I simply removed the ngpu > 1 condition and it works fine now 🙂
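
For anyone who lands here later, my load path now looks roughly like this (the path is a placeholder, and Generator(ngpu) stands in for however your model is constructed):

import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
netG = Generator(ngpu).to(device)  # construct the generator as in training
if device.type == 'cuda':
    # wrap unconditionally (no ngpu > 1 check), so the keys carry the same
    # 'module.' prefix the checkpoint was saved with
    netG = nn.DataParallel(netG, list(range(ngpu)))
netG.load_state_dict(torch.load('checkpoint.pth')['basic_celeba_gen'])
netG.eval()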

Thanks again!