Error(s) in loading state_dict: missing keys in inference model

Hi,
I have my model like this:

import torch
import torch.nn as nn
from hrnet import HRNet

Model = HRNet(in_c, out_c)
class modified_model(nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.conv1 = nn.Conv2d(128, 64, 3, 2, padding=1)
        self.conv2 = nn.Conv2d(64, 50, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(50, 50, 3, 2, padding=1)

    def forward(self, x):
        x = self.backbone(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(2, 25, 1024)          # assumes a fixed batch size of 2
        coord_x = x[0, :, :].reshape(1, 25, 1024)
        coord_y = x[1, :, :].reshape(1, 25, 1024)
        return coord_x, coord_y

new_model = modified_model(Model)
… training and saving the state dictionary
For inference I build the same model:

Model = HRNet(in_c, out_c)
test_model = modified_model(Model)

PATH = 'path/to/checkpoint/'
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
test_model.load_state_dict(checkpoint['state_dict'])

Now it gives an error about missing keys…
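
For reference, this is roughly how I compare the keys (just a diagnostic sketch built on the variables above, not part of my training code):

ckpt_keys = set(checkpoint['state_dict'].keys())
model_keys = set(test_model.state_dict().keys())
print('expected by model but missing in checkpoint:', model_keys - ckpt_keys)
print('present in checkpoint but not expected:', ckpt_keys - model_keys)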

Hi,
Can you show the part where you save the model?

As per the documentation, there are two ways of saving a model.

First - saving the state dict, where only the weights are saved: torch.save(model.state_dict(), PATH). Loading then requires creating the model class separately before loading the weights: model = TheModelClass(*args, **kwargs) followed by model.load_state_dict(torch.load(PATH)).

Second - saving and loading the whole model: torch.save(model, PATH) and model = torch.load(PATH).
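
In code, that looks roughly like this (TheModelClass, args/kwargs and PATH are placeholders from the tutorial, not your variables):

import torch

# 1) state_dict only: just the weights go to disk,
#    so the class has to be instantiated again before loading
model = TheModelClass(*args, **kwargs)
torch.save(model.state_dict(), PATH)

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

# 2) whole model: the module object itself is pickled
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()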

I presume you have mixed the methods, but to be sure it would be great to see the saving code.

Saving and Loading Models — PyTorch Tutorials 1.13.1+cu117 documentation

Hi,
Thanks for the reply,
I saved the model as follows:

import shutil
import torch

checkpoint = {
    'epoch': epoch + 1,
    'valid_loss_min': valid_loss,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}

def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint we want to save
    is_best: is this the best checkpoint; min validation loss
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best model
    """
    f_path = checkpoint_path
    # save checkpoint data to the path given, checkpoint_path
    torch.save(state, f_path)
    # if it is the best model (min validation loss), copy the
    # checkpoint file to best_model_path
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)
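
For loading I mirror it roughly like this (load_ckp is just a sketch, not copied from my training script); the point is that 'state_dict' has to be read out of the checkpoint dict before calling load_state_dict:

def load_ckp(checkpoint_path, model, optimizer):
    # sketch: restore exactly the keys written by save_ckp above
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch'], checkpoint['valid_loss_min']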