Weights are not loaded/saved after training


I need help! It looks like PyTorch either does not save the model state dictionary or does not apply it on load. I train and save a model. If I load the model in the same Jupyter notebook, it works. However, if I load it in another notebook, the inference results resemble those of an untrained model. At the same time, a simple classifier (similar to the one below, but without ResNet18) works fine when loaded.

Any help or suggestion is appreciated. Thank you!

Below are the details.
I am training an image classifier based on a pretrained ResNet18, which is wrapped in my custom model class like this:

import torch.nn as nn
import torchvision


class CLF_resnet18(nn.Module):

    def __init__(self, num_class, dropout=0.2):
        super(CLF_resnet18, self).__init__()
        # Pretrained ResNet18 backbone; its final layer is replaced by a 512 -> 128 projection
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(512, 128)
        self.flatten = nn.Flatten()
        # Classifier head: 128 -> 32 -> 8 -> num_class
        self.layer_1 = nn.Linear(128, 32)
        self.layer_2 = nn.Linear(32, 8)
        self.layer_out = nn.Linear(8, num_class)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.batchnorm1 = nn.BatchNorm1d(32)
        self.batchnorm2 = nn.BatchNorm1d(8)

    def forward(self, x):
        x = self.resnet(x)
        # Flatten is a no-op here: the backbone output is already (N, 128)
        x = self.layer_1(self.flatten(x))
        x = self.batchnorm1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_2(x)
        x = self.batchnorm2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_out(x)
        return x

I initialize the model:

model = CLF_resnet18(len(LABELS))
model = model.apply(initialize_weights)
model = model.to(device)

I save the model state dictionary, as the documentation suggests:

model_full_name = os.path.join(model_path, state_dict_fname)
params = model.state_dict()
torch.save(params, model_full_name)


To load, I first re-create the model:

model_load = CLF_resnet18(len(LABELS))
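
For comparison, here is a minimal sketch of a complete state-dict round trip. `TinyNet` and the temporary file path are hypothetical stand-ins, not the model or files from this post. The point is only the order of calls: save `state_dict()`, rebuild the same architecture, then pass the loaded dictionary to `load_state_dict` and inspect what it reports.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for CLF_resnet18, small enough to run anywhere."""

    def __init__(self, num_class):
        super().__init__()
        self.layer_1 = nn.Linear(16, 8)
        self.batchnorm1 = nn.BatchNorm1d(8)
        self.layer_out = nn.Linear(8, num_class)

    def forward(self, x):
        return self.layer_out(torch.relu(self.batchnorm1(self.layer_1(x))))


# "Training" side: save only the state dict.
model = TinyNet(num_class=3)
path = os.path.join(tempfile.mkdtemp(), "state_dict.pt")  # hypothetical path
torch.save(model.state_dict(), path)

# "Inference" side: rebuild the architecture, then load the weights into it.
model_load = TinyNet(num_class=3)
result = model_load.load_state_dict(torch.load(path, map_location="cpu"))

# With the default strict=True a key mismatch raises an error;
# on success both lists are empty.
print(result.missing_keys, result.unexpected_keys)

# Every tensor in the reloaded model should match the saved one exactly.
all_equal = all(
    torch.equal(saved, model_load.state_dict()[name])
    for name, saved in model.state_dict().items()
)
print(all_equal)
```

If the constructor alone is called and `load_state_dict` is never applied, the new instance keeps its freshly initialized weights, which would look like an untrained model.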

Environment: Ubuntu 20.04, Anaconda environment with the latest PyTorch (conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch)

Did you apply the same preprocessing to the samples in both notebooks?
If so, did you call model.eval() to disable dropout and to use the running stats in the batchnorm layers?
Once this is done, compare the model outputs from both notebooks for a static input (e.g. torch.ones) and calculate the abs().max() error between them.
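
A minimal sketch of that check, assuming a small stand-in network rather than the ResNet18-based model (the `make_model` helper and the layer sizes are hypothetical). It copies the state dict into a second instance, switches both to `eval()`, and compares the outputs on `torch.ones`.

```python
import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in with the same kinds of layers (batchnorm, dropout).
    return nn.Sequential(
        nn.Linear(16, 8),
        nn.BatchNorm1d(8),
        nn.ReLU(),
        nn.Dropout(p=0.2),
        nn.Linear(8, 3),
    )


torch.manual_seed(0)
model = make_model()

# Run a few batches in train mode so the batchnorm running stats
# move away from their initial values, as they would during training.
model.train()
for _ in range(5):
    model(torch.randn(32, 16))

# Second instance, as it would be created in the other notebook.
model_load = make_model()
model_load.load_state_dict(model.state_dict())

# Static input, both models in eval mode.
x = torch.ones(4, 16)
model.eval()
model_load.eval()
with torch.no_grad():
    err_eval = (model(x) - model_load(x)).abs().max().item()
print("eval vs eval:", err_eval)  # identical weights, so this should be 0.0

# For contrast: leave the loaded model in train mode. Dropout is active and
# batchnorm uses batch statistics, so the outputs will usually differ.
model_load.train()
with torch.no_grad():
    err_train = (model(x) - model_load(x)).abs().max().item()
print("eval vs train:", err_train)
```

A zero error in eval mode points away from the saved weights and toward preprocessing or a missing model.eval() call; a large error in eval mode suggests the state dict was not actually applied.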