Error when saving and loading GAN models

I am attempting to save my trained GAN models so that I can either resume training later, or use my saved models for inference.

My first attempt was to save the entire model object:

model = myGAN(epochs = 10, additional_args)
model.fit(data)

torch.save(model, "model.pth")

But this prompts the following error:

**PicklingError** : Can't pickle <class 'Model.myGAN.columnInfo'>: attribute lookup columnInfo on Model.myGAN failed

columnInfo is a named tuple used in my model that holds my data’s feature descriptions, but apparently it cannot be pickled.
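For context on the PicklingError: pickle stores a class by reference to its importable name, so a named tuple that is created inside a method, or whose type name differs from the variable it is bound to, typically fails this attribute lookup. One possible way around it, sketched below, is to convert the tuples to plain dicts before saving and rebuild them after loading. The `ColumnInfo` fields and the helper names `to_plain` and `from_plain` are assumptions for illustration, not part of the original model.

```python
import pickle
from collections import namedtuple

# Hypothetical stand-in for columnInfo; the field names are assumptions.
ColumnInfo = namedtuple("ColumnInfo", ["name", "kind", "size"])


def to_plain(infos):
    """Convert named tuples to plain dicts, which pickle handles without a class lookup."""
    return [dict(info._asdict()) for info in infos]


def from_plain(rows):
    """Rebuild the named tuples from the plain dicts after loading."""
    return [ColumnInfo(**row) for row in rows]


infos = [
    ColumnInfo("age", "continuous", 1),
    ColumnInfo("city", "categorical", 12),
]

# Only built-in types end up in the serialized payload.
blob = pickle.dumps(to_plain(infos))
restored = from_plain(pickle.loads(blob))
```

The same plain-dict list could be stored as an extra entry next to the weights in whatever is passed to torch.save.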

My second attempt was to save and load the model the way PyTorch recommends, through its state_dict:

model = myGAN(epochs = 10, additional_args)
model.fit(data)

torch.save(model.state_dict(), "model.pth")

model2 = myGAN(epochs = 5, additional_args)
model2.load_state_dict(torch.load("model.pth"))

But that returns the following error:

RuntimeError: Error(s) in loading state_dict for myGAN:
	Unexpected key(s) in state_dict: 

I believe this happens because some of the unexpected keys belong to layers that are only initialised when the model begins training, so model2 does not recognise them before that. Are there any workarounds for these errors?
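To check that suspicion, the key sets of a trained and a freshly constructed model can be compared directly. The sketch below uses a hypothetical `Lazy` module that, like the model described above, creates one layer only when `fit` is called; the class and layer names are assumptions.

```python
import torch.nn as nn


class Lazy(nn.Module):
    """Hypothetical module that creates part of itself only when training starts."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(2, 2)

    def fit(self):
        # This layer does not exist until fit() has been called.
        self.head = nn.Linear(2, 1)


trained = Lazy()
trained.fit()
fresh = Lazy()

saved_keys = set(trained.state_dict())
fresh_keys = set(fresh.state_dict())

# Keys in the checkpoint that a fresh model has no layer for.
unexpected = sorted(saved_keys - fresh_keys)
# Keys the fresh model expects but the checkpoint lacks.
missing = sorted(fresh_keys - saved_keys)

print("unexpected:", unexpected)
print("missing:", missing)
```

If the unexpected keys all belong to layers created during training, the fresh model needs those layers built before load_state_dict can accept the checkpoint.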

Edit1: Loading the model parameters only works if it is done on the same model I initialised and trained. e.g.:

model = myGAN(epochs = 10, additional_args)
model.fit(data)

torch.save(model.state_dict(), "model.pth")

model.load_state_dict(torch.load("model.pth"))

You can try passing strict=False to model.load_state_dict().
That might help!

load_state_dict documentation

Unfortunately, using strict=False did not work for me. I did find a workaround, although it is a bit awkward:

model = myGAN(additional_args)
model.fit(data, epochs = 10)

torch.save({"epoch" : 10,
            "model_state_dict" : model.state_dict(),
            "G_optimizer_state_dict" : model.G_optimizer.state_dict(),
            "D_optimizer_state_dict" : model.D_optimizer.state_dict() }, 
            "model.pth")

model2 = myGAN()
model2.fit(data, epochs = 0)

checkpoint = torch.load("model.pth")
model2.load_state_dict(checkpoint['model_state_dict'])
model2.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])
model2.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])
 
model2.fit(data, epochs = 10)

I moved the epochs argument into the fit function, but I still have to call fit with 0 epochs before loading, so that the model creates the layers and optimizers that the saved parameters belong to.
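A possibly cleaner variant of this workaround is to move the layer and optimizer creation into its own method, so a new instance can be prepared for loading without a dummy call to fit. This is only a sketch under assumptions: `TinyGAN`, `build`, `data_dim` and `noise_dim` are hypothetical names, the networks are single linear layers, and the training loop is omitted.

```python
import io

import torch
import torch.nn as nn


class TinyGAN(nn.Module):
    """Hypothetical stand-in for myGAN whose networks depend on the data shape."""

    def __init__(self):
        super().__init__()
        self.built = False

    def build(self, data_dim, noise_dim=4):
        # Create everything that fit() would otherwise create lazily.
        self.G = nn.Linear(noise_dim, data_dim)
        self.D = nn.Linear(data_dim, 1)
        self.G_optimizer = torch.optim.Adam(self.G.parameters())
        self.D_optimizer = torch.optim.Adam(self.D.parameters())
        self.built = True

    def fit(self, data, epochs):
        if not self.built:
            self.build(data.shape[1])
        for _ in range(epochs):
            pass  # training loop omitted in this sketch


data = torch.randn(8, 3)
model = TinyGAN()
model.fit(data, epochs=1)

# Save the shape needed to rebuild the networks alongside the weights.
buffer = io.BytesIO()  # an in-memory file; a path such as "model.pth" works the same way
torch.save(
    {
        "data_dim": data.shape[1],
        "model_state_dict": model.state_dict(),
        "G_optimizer_state_dict": model.G_optimizer.state_dict(),
        "D_optimizer_state_dict": model.D_optimizer.state_dict(),
    },
    buffer,
)
buffer.seek(0)

checkpoint = torch.load(buffer)
model2 = TinyGAN()
model2.build(checkpoint["data_dim"])  # replaces the fit(data, epochs=0) call
model2.load_state_dict(checkpoint["model_state_dict"])
model2.G_optimizer.load_state_dict(checkpoint["G_optimizer_state_dict"])
model2.D_optimizer.load_state_dict(checkpoint["D_optimizer_state_dict"])
```

Storing the build arguments in the checkpoint means inference code does not need the training data just to recreate the architecture.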

I am not sure why it didn’t work for you, but the code below works:

import torch
import torch.nn as nn

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.Linear(5,6)
        self.m2 = nn.Linear(6,7)

    def create_dynamic_linear(self):
        self.m3 = nn.Linear(3,3)

    def forward(self,):
        pass

if __name__ == '__main__':
    model1 = model()
    model1.create_dynamic_linear()
    torch.save(model1.state_dict(), 'model.pth')

    model2 = model()
    miss, unexp = model2.load_state_dict(torch.load('model.pth'), strict=False)

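    # miss is empty; unexp lists the m3 parameters that model2 has no layer for
    print(miss, unexp)  # [] ['m3.weight', 'm3.bias']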
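One caveat worth checking with this approach: strict=False makes the load succeed by skipping the unexpected keys, so the dynamically created layer is not restored in the new model. The sketch below, with a hypothetical `Net` class modelled on the example above, illustrates this and may explain why strict=False alone was not enough for a model whose layers are created during training.

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical module with one layer that is added after construction."""

    def __init__(self):
        super().__init__()
        self.m1 = nn.Linear(5, 6)

    def create_dynamic_linear(self):
        self.m3 = nn.Linear(3, 3)


source = Net()
source.create_dynamic_linear()

buffer = io.BytesIO()  # in-memory stand-in for 'model.pth'
torch.save(source.state_dict(), buffer)
buffer.seek(0)

target = Net()
result = target.load_state_dict(torch.load(buffer), strict=False)

# The m3 parameters are reported as unexpected and are not loaded anywhere.
skipped = sorted(result.unexpected_keys)
has_m3 = hasattr(target, "m3")
print(skipped, has_m3)
```

To actually restore m3, the target would need to call create_dynamic_linear() before loading, at which point the default strict=True also succeeds.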