I am attempting to save my trained GAN models so that I can either resume training later, or use my saved models for inference.
My first attempt was to save the entire model:
model = myGAN(epochs = 10, additional_args)
model.fit(data)
torch.save(model, "model.pth")
But this raises the following error:
**PicklingError** : Can't pickle <class 'Model.myGAN.columnInfo'>: attribute lookup columnInfo on Model.myGAN failed
`columnInfo` is a named tuple used in my model that holds my data's feature descriptions, but apparently it cannot be pickled.
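For context, one common cause of this kind of "attribute lookup ... failed" error is that pickle stores a class by reference, so it has to find the named tuple again as a module-level attribute under its own type name. The sketch below is only an illustration of that behaviour, not my actual code: `ColumnInfo`, `LocalInfo`, and the field names are made up, and I am assuming my `columnInfo` is created somewhere pickle cannot look it up (for example inside a function or class body).

```python
import pickle
from collections import namedtuple

# Module-level definition: the variable name matches the type name, so pickle
# can find the class again through attribute lookup on this module.
ColumnInfo = namedtuple("ColumnInfo", ["name", "kind"])  # hypothetical fields


def make_local_info():
    # Defined inside a function: the class never becomes a module attribute,
    # so pickle cannot resolve it by name.
    LocalInfo = namedtuple("LocalInfo", ["name", "kind"])
    return LocalInfo("age", "continuous")


def can_pickle(obj):
    """Return True if obj can be serialised with pickle, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


module_level_ok = can_pickle(ColumnInfo("age", "continuous"))
local_ok = can_pickle(make_local_info())
print(module_level_ok, local_ok)
```

If this is indeed the cause, moving the named tuple definition to module level (with the variable name equal to the type name) might be enough for `torch.save(model, ...)` to work.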
My second attempt was to save and load only the model's `state_dict`, which is the approach PyTorch recommends:
model = myGAN(epochs = 10, additional_args)
model.fit(data)
torch.save(model.state_dict(), "model.pth")
model2 = myGAN(epochs = 5, additional_args)
model2.load_state_dict(torch.load("model.pth"))
But that raises the following error:
RuntimeError: Error(s) in loading state_dict for myGAN:
Unexpected key(s) in state_dict:
I believe this happens because some of the unexpected keys are only initialised when the model begins training, so model2 does not recognise them before then. Are there any workarounds for these errors?
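To make my suspicion concrete, here is a minimal sketch that reproduces the same kind of error and shows one possible workaround. Everything in it is an assumption for illustration: `LazyGAN`, its `build` method, and the `input_dim` argument are hypothetical stand-ins for my real `myGAN`, which I assume creates its networks inside `fit()` once the data shape is known. The idea is to save whatever is needed to rebuild the networks alongside the `state_dict`, and to build them before calling `load_state_dict`.

```python
import io

import torch
from torch import nn


class LazyGAN(nn.Module):
    """Hypothetical stand-in for myGAN: networks are created lazily in fit()."""

    def __init__(self, input_dim=None):
        super().__init__()
        # If the data dimension is known up front, build the networks now.
        if input_dim is not None:
            self.build(input_dim)

    def build(self, input_dim):
        self.generator = nn.Linear(4, input_dim)
        self.discriminator = nn.Linear(input_dim, 1)

    def fit(self, data):
        if not hasattr(self, "generator"):
            self.build(data.shape[1])
        # Training loop omitted in this sketch.


data = torch.randn(8, 3)
trained = LazyGAN()
trained.fit(data)

# Save the build argument together with the weights.
buffer = io.BytesIO()  # in-memory stand-in for "model.pth"
torch.save({"input_dim": data.shape[1], "state_dict": trained.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# A fresh, unbuilt model has no parameters yet, so every saved key is unexpected.
fresh = LazyGAN()
try:
    fresh.load_state_dict(checkpoint["state_dict"])
    fresh_load_failed = False
except RuntimeError:
    fresh_load_failed = True

# Building the networks first gives load_state_dict matching keys to fill.
restored = LazyGAN(input_dim=checkpoint["input_dim"])
restored.load_state_dict(checkpoint["state_dict"])
```

Passing `strict=False` to `load_state_dict` would silence the error, but it would also skip the unexpected keys, so the trained weights would not actually be loaded.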
Edit 1: Loading the model parameters only works if it is done on the same model instance that I initialised and trained, e.g.:
model = myGAN(epochs = 10, additional_args)
model.fit(data)
torch.save(model.state_dict(), "model.pth")
model.load_state_dict(torch.load("model.pth"))