I need help! It looks like PyTorch either does not save the model's state dictionary or does not apply it when loading. I train and save a model. If I load the model in the same Jupyter notebook, it works. However, if I load it in another notebook, the inference results resemble those of an untrained model. At the same time, a simpler classifier (similar to the one below, but without ResNet18) works fine when loaded.
Any help or suggestion is appreciated. Thank you!
Below are the details.
I am training an image classifier based on a pretrained ResNet18, wrapped in my custom model class like this:
```python
import torch
import torch.nn as nn
import torchvision


class CLF_resnet18(nn.Module):
    def __init__(self, num_class, dropout=0.2):
        super(CLF_resnet18, self).__init__()
        # Pretrained ResNet18 backbone; its final fc layer now outputs 128 features
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(512, 128)
        # Classifier head: 128 -> 32 -> 8 -> num_class
        self.flatten = nn.Flatten()
        self.layer_1 = nn.Linear(128, 32)
        self.layer_2 = nn.Linear(32, 8)
        self.layer_out = nn.Linear(8, num_class)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.batchnorm1 = nn.BatchNorm1d(32)
        self.batchnorm2 = nn.BatchNorm1d(8)

    def forward(self, x):
        x = self.resnet(x)
        x = self.layer_1(self.flatten(x))
        x = self.batchnorm1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_2(x)
        x = self.batchnorm2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_out(x)
        return x
```
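The head above mixes `nn.Dropout` and `nn.BatchNorm1d`, and both behave differently in training mode and evaluation mode. A minimal sketch of that effect, assuming a hypothetical `Head` class that copies only the head layers (the ResNet18 backbone is left out so nothing has to be downloaded) and random 128-dimensional inputs standing in for ResNet features:

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for the head of CLF_resnet18 (backbone omitted)
class Head(nn.Module):
    def __init__(self, num_class, dropout=0.2):
        super().__init__()
        self.layer_1 = nn.Linear(128, 32)
        self.layer_2 = nn.Linear(32, 8)
        self.layer_out = nn.Linear(8, num_class)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.batchnorm1 = nn.BatchNorm1d(32)
        self.batchnorm2 = nn.BatchNorm1d(8)

    def forward(self, x):
        x = self.dropout(self.relu(self.batchnorm1(self.layer_1(x))))
        x = self.dropout(self.relu(self.batchnorm2(self.layer_2(x))))
        return self.layer_out(x)


torch.manual_seed(0)
head = Head(num_class=3)
x = torch.randn(4, 128)  # fake batch of 4 feature vectors

# Training mode: dropout masks are random, BatchNorm uses per-batch statistics
head.train()
train_a, train_b = head(x), head(x)

# Evaluation mode: dropout is disabled, BatchNorm uses its stored running statistics
head.eval()
with torch.no_grad():
    eval_a, eval_b = head(x), head(x)

print("train-mode outputs identical:", torch.equal(train_a, train_b))
print("eval-mode outputs identical:", torch.equal(eval_a, eval_b))
```

In training mode the two passes differ because of dropout; in evaluation mode they are identical. Which mode the model is in at inference time therefore changes the predictions even when the weights are the same.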
I initialize the model:
```python
# initialize_weights is my own function (not shown here)
model = CLF_resnet18(len(LABELS))
model = model.apply(initialize_weights)
model = model.to(device)
```
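`nn.Module.apply(fn)` calls `fn` on every submodule recursively (children first, then the module itself) and returns the module. That means `initialize_weights` also visits every layer inside `self.resnet`, so any layer type it re-initializes there loses its pretrained values. Since `initialize_weights` is not shown, the sketch below uses a hypothetical `init_linear_only` function and a small stand-in model (a `Conv2d` playing the role of the pretrained backbone) to show which layers get touched:

```python
import torch
import torch.nn as nn


# Hypothetical init function (NOT the original initialize_weights):
# it re-initializes nn.Linear layers only and leaves all other layers alone.
def init_linear_only(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        nn.init.zeros_(m.bias)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),  # stands in for a pretrained backbone layer
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),  # stands in for the new classifier head
)

# apply() visits each child module, then the container itself
visited = []
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)

conv_before = model[0].weight.detach().clone()
linear_before = model[3].weight.detach().clone()

returned = model.apply(init_linear_only)
print("apply returns the same module:", returned is model)
print("conv weights unchanged:", torch.equal(model[0].weight, conv_before))
print("linear weights unchanged:", torch.equal(model[3].weight, linear_before))
```

Filtering on the layer type (or applying the function only to the head modules) is one way to keep a backbone's pretrained weights untouched.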
I save the model's state dictionary, as the PyTorch documentation suggests:
```python
import os

import torch

model_full_name = os.path.join(model_path, state_dict_fname)
params = model.state_dict()
torch.save(params, model_full_name)
```
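A state dictionary holds more than the learnable weights: for every `BatchNorm1d` layer it also stores the `running_mean`, `running_var` and `num_batches_tracked` buffers. A minimal sketch listing what gets saved, assuming a small `nn.Sequential` stand-in with the same kinds of layers as the head:

```python
import torch
import torch.nn as nn

# Stand-in with the same layer types as the CLF_resnet18 head
model = nn.Sequential(
    nn.Linear(128, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(32, 3),
)

sd = model.state_dict()
for name, tensor in sd.items():
    # ReLU and Dropout have no parameters or buffers, so they do not appear
    print(name, tuple(tensor.shape))
```

The BatchNorm entries (`1.running_mean`, `1.running_var`, `1.num_batches_tracked`) are buffers rather than parameters, but `state_dict()` includes them, so `torch.save(model.state_dict(), ...)` saves them too.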
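In the other notebook, I load it like this:

```python
# state_dict_path is the path of the file saved above
model_load = CLF_resnet18(len(LABELS))
model_load.load_state_dict(torch.load(state_dict_path))
```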
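One way to check whether the saved weights are really applied is a save/load round trip that compares every tensor and the eval-mode predictions. A minimal sketch, assuming a hypothetical `make_model` stand-in (used instead of `CLF_resnet18(len(LABELS))` to avoid downloading ResNet18 weights), a temporary file path, and CPU only; with the real model you would pass `map_location=device` instead of `"cpu"`:

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for CLF_resnet18(len(LABELS))
    return nn.Sequential(
        nn.Linear(128, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 3)
    )


torch.manual_seed(0)
trained = make_model()
# A few training-mode passes so the BatchNorm running stats move off their defaults
trained.train()
for _ in range(5):
    trained(torch.randn(16, 128))

path = os.path.join(tempfile.mkdtemp(), "clf_state_dict.pt")
torch.save(trained.state_dict(), path)

# "Other notebook": build a fresh model and load the saved state dict into it
loaded = make_model()
result = loaded.load_state_dict(torch.load(path, map_location="cpu"))
# With strict=True (the default) any missing/unexpected key raises an error;
# the returned object lists them and prints "<All keys matched successfully>" otherwise
print(result)

# Every parameter and buffer should match exactly
same_params = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), loaded.state_dict().values())
)

# Compare predictions with both models in evaluation mode
trained.eval()
loaded.eval()
x = torch.randn(4, 128)
with torch.no_grad():
    same_outputs = torch.equal(trained(x), loaded(x))

print("state dicts match:", same_params)
print("eval-mode outputs match:", same_outputs)
```

If both checks pass for the real model, the saved weights are being applied, and the difference between the two notebooks lies somewhere other than the state dictionary itself.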
OS: Ubuntu 20.04, anaconda environment with latest PyTorch (conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch)