State_dict mismatch when testing a model trained from a loaded checkpoint

I trained a model using PyTorch and saved the weights from the point where it had the lowest validation loss during training. I also saved the optimizer state at that point. I then trained the model a second time, starting from the weights saved in the first run. Now I want to test the model and generate its outputs, but unfortunately I get the following error:

To load the weights for the second training run, I used the following code:

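    import torch
    import torch.nn as nn

    # Wrap the model for multi-GPU training; output_device takes a single device, not a list.
    model = nn.DataParallel(model, device_ids=[0, 1], output_device=0)
    model = model.to(dev)

    # model_weights_path / optimizer_weights_path: the files saved during the first run.
    # The model is already wrapped, so its state_dict keys start with "module.".
    model.load_state_dict(torch.load(model_weights_path, map_location=dev))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=2e-7)

    # Loading the optimizer state
    optimizer.load_state_dict(torch.load(optimizer_weights_path, map_location=dev))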
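A minimal sketch of saving the model and optimizer state together in one checkpoint and resuming from it. The helper names, the file name and the dictionary keys (`'model_state_dict'`, `'optimizer_state_dict'`) are assumptions chosen to follow a common convention. One detail matters for this question: `optimizer.load_state_dict` also restores the saved `param_groups`, including the old learning rate, so a new `lr` passed to the constructor is overwritten unless it is set again after loading.

```python
import os
import tempfile

import torch
import torch.nn as nn


def unwrap(model):
    # nn.DataParallel keeps the real network in .module; its keys have no "module." prefix.
    return model.module if isinstance(model, nn.DataParallel) else model


def save_checkpoint(model, optimizer, path):
    # One file holding both the model weights and the optimizer state.
    torch.save(
        {
            "model_state_dict": unwrap(model).state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(model, optimizer, path, device="cpu", new_lr=None):
    checkpoint = torch.load(path, map_location=device)
    unwrap(model).load_state_dict(checkpoint["model_state_dict"])
    # This also restores the saved hyperparameters, including the old learning rate.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if new_lr is not None:
        # Re-apply the new learning rate after the optimizer state has been loaded.
        for group in optimizer.param_groups:
            group["lr"] = new_lr


# Usage: save after the first run, then resume the second run with a smaller lr.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-7)
save_checkpoint(model, optimizer, path)

model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-4, weight_decay=2e-7)
load_checkpoint(model2, optimizer2, path, new_lr=1e-4)
print(optimizer2.param_groups[0]["lr"])  # 0.0001
```

Saving the unwrapped module's weights keeps the file loadable whether or not the model is later wrapped in `nn.DataParallel`.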

The only thing I changed for the second training run was the learning rate. When I now try to test the model and generate its outputs by loading the new weights, I get this error. How can I solve the problem?

This error occurs because the keys in your model's state_dict don't match the keys in the state_dict you are trying to load. Print the keys of your model's state_dict and of the state_dict returned by torch.load, and compare them. If they are not the same, load_state_dict (which uses strict=True by default) raises this error.
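A minimal sketch of that comparison. `compare_state_dict_keys` is a hypothetical helper, and the two `nn.Sequential` models only simulate an architecture change between training runs; in practice `saved` would come from `torch.load`.

```python
import torch
import torch.nn as nn


def compare_state_dict_keys(model, state_dict):
    # Keys the model expects but the loaded dict lacks, and keys the dict has but the model does not.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    missing = sorted(model_keys - loaded_keys)
    unexpected = sorted(loaded_keys - model_keys)
    return missing, unexpected


# Weights saved from an older architecture: Conv -> ReLU -> Conv.
old_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
saved = old_model.state_dict()  # in practice: torch.load(path, map_location="cpu")

# Current architecture without the ReLU, so the second conv moves from index 2 to index 1.
new_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))

missing, unexpected = compare_state_dict_keys(new_model, saved)
print("missing:", missing)        # ['1.bias', '1.weight']
print("unexpected:", unexpected)  # ['2.bias', '2.weight']
```

The example also shows that the weights themselves may be fine: only the names changed, because `nn.Sequential` names its children by position.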

When you trained the model, did you use nn.DataParallel?
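The reason this matters: `nn.DataParallel` stores the wrapped network in its `module` attribute, so every key of its state_dict gets a `module.` prefix. Weights saved from a wrapped model do not load into an unwrapped one, and vice versa. A minimal sketch of stripping the prefix, assuming the file holds a plain state_dict; `strip_prefix` is a hypothetical helper. Alternatives are saving `model.module.state_dict()` in the first place or loading the weights before wrapping; recent PyTorch versions also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which removes a prefix in place.

```python
import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="module."):
    # Copy of state_dict with `prefix` removed from every key that starts with it.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)  # can also be constructed on a CPU-only machine
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

plain = strip_prefix(wrapped.state_dict())
fresh = nn.Linear(4, 2)
fresh.load_state_dict(plain)  # succeeds: the keys are now 'weight' and 'bias'
```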

This is due to a mismatch in the layer (key) names between the first and the second training run.
Check the key names in the state_dict saved during the first run and compare them with the key names in the current model's state_dict. If any key names differ, the following code may help: it renames those keys in the loaded state_dict to the names the current model uses.

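    from collections import OrderedDict

    import torch


    # Method of your nn.Module subclass.
    def LoadStateDictCustom(self, StateDicPath):
        # Load onto the GPU if one is available, otherwise onto the CPU.
        StatDic = torch.load(StateDicPath, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
        # Assumes the checkpoint is a dict that stores the weights under 'model_state_dict'.
        StatDic = StatDic['model_state_dict']
        StatDic2 = OrderedDict()
        for key, value in StatDic.items():
            if key == 'OLD key name that must be changed':
                # Store the tensor under the key name the current model expects.
                StatDic2['NEW key name for second time'] = value
            else:
                StatDic2[key] = value
        self.load_state_dict(StatDic2)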
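The same idea extends to several renamed layers at once: keep a dictionary from old key prefixes to new ones and rewrite every matching key. A minimal sketch; `rename_state_dict_keys`, the two small model classes and the `{"conv1.": "encoder.0."}` mapping are hypothetical. With a checkpoint like the one above, you would pass `checkpoint['model_state_dict']` instead of `old_state`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def rename_state_dict_keys(state_dict, rename_map):
    # rename_map maps an old key prefix (e.g. "conv1.") to its new prefix (e.g. "encoder.0.").
    renamed = OrderedDict()
    for key, value in state_dict.items():
        for old, new in rename_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed


class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)


class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 8, 3))


old_state = OldNet().state_dict()  # keys: conv1.weight, conv1.bias
new_model = NewNet()               # keys: encoder.0.weight, encoder.0.bias
new_model.load_state_dict(rename_state_dict_keys(old_state, {"conv1.": "encoder.0."}))
```

Matching on prefixes rather than full key names means one entry covers both the weight and the bias of a layer.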

Thanks for your answer. Could you please provide an example of how to use your function, and how its output should be used?

You can copy this function into your model class and pass the path of the saved state_dict to it.
The function loads the given state_dict and builds a new one that is identical to it, except that the old key names are replaced by the new key names your current model expects:

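    from collections import OrderedDict

    import torch
    import torch.nn as nn


    class YourModel(nn.Module):
        def __init__(self, n_classes, in_chann):
            super(YourModel, self).__init__()
            # ... define your layers here ...

        def LoadStateDictCustom(self, StateDicPath):
            # Here the file is assumed to hold the state_dict itself (no 'model_state_dict' wrapper).
            StatDic = torch.load(StateDicPath, map_location=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
            StatDic2 = OrderedDict()
            for key, value in StatDic.items():
                if key == 'OLD key name that must be changed':
                    # Rename the key to the name used by the current model.
                    StatDic2['NEW key name for second time'] = value
                else:
                    StatDic2[key] = value
            self.load_state_dict(StatDic2)

        def forward(self, x, trg=None):
            # ... your forward pass here ...
            pass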

Thanks a lot for the update. My main question is this: my model class has several sequential convolutional layers in its forward function. Where in the forward function should I call LoadStateDictCustom? Should I call it in my training script, or inside forward?

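You should call it in your script, once after creating the model, in place of the model.load_state_dict(torch.load(...)) line; not inside forward.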
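A minimal sketch of where the call goes in a test script: once, after the model is built and before inference, never in `forward` (which runs on every batch). The layer names, the rename from `conv.*` to `features.*`, the file name and the input shape are all hypothetical, chosen only to make the example runnable; a training script would make the same call before the first epoch. If you wrap the model in `nn.DataParallel`, call it on the unwrapped model before wrapping.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


class YourModel(nn.Module):
    def __init__(self, n_classes, in_chann):
        super(YourModel, self).__init__()
        # Hypothetical layers; "features" was called "conv" when the old weights were saved.
        self.features = nn.Conv2d(in_chann, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, n_classes)

    def LoadStateDictCustom(self, StateDicPath):
        StatDic = torch.load(StateDicPath, map_location="cpu")
        StatDic2 = OrderedDict()
        for key, value in StatDic.items():
            # Rename "conv.*" keys from the old checkpoint to "features.*".
            if key.startswith("conv."):
                key = "features." + key[len("conv."):]
            StatDic2[key] = value
        self.load_state_dict(StatDic2)

    def forward(self, x):
        # forward only computes outputs; no weight loading happens here.
        x = torch.relu(self.features(x))
        x = x.mean(dim=(2, 3))  # global average pooling over height and width
        return self.head(x)


# Simulate the file saved in the first run, which still used the old name "conv".
path = os.path.join(tempfile.mkdtemp(), "weights.pt")
first_run = YourModel(n_classes=3, in_chann=1)
torch.save(
    {k.replace("features.", "conv.", 1): v for k, v in first_run.state_dict().items()},
    path,
)

# Test script: build the model, load the weights once, then run inference.
model = YourModel(n_classes=3, in_chann=1)
model.LoadStateDictCustom(path)  # replaces model.load_state_dict(torch.load(path))
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 1, 16, 16))
print(out.shape)  # torch.Size([2, 3])
```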

@saluei, thank you again for your solution. I have run into a new challenge: I want to load the weights and fine-tune my model after changing some of its layers. My plan is to freeze all the parameters of the previous model and train only specific layers. In this scenario, however, I get an error when loading the state_dict. How can I apply strict=False in your proposed solution?
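`load_state_dict` accepts `strict=False`, which tolerates keys that are missing from the checkpoint or unexpected by the model and reports them in the returned `missing_keys` / `unexpected_keys`. In `LoadStateDictCustom` that amounts to changing the last line to `self.load_state_dict(StatDic2, strict=False)`. It does not, however, skip tensors whose shape changed: those still raise a size-mismatch `RuntimeError`, so replaced layers with new shapes have to be filtered out first. A minimal sketch, assuming the old and new models differ only in their output layer; `load_matching_weights` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def load_matching_weights(model, state_dict):
    # Keep only entries whose name exists in the model and whose shape is unchanged.
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items() if k in own and v.shape == own[k].shape}
    # strict=False tolerates the keys that were dropped or are new in the model.
    result = model.load_state_dict(filtered, strict=False)
    return sorted(filtered), result.missing_keys


# Old model: 10-class output layer. New model: same first layer, new 4-class output layer.
old = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
new = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

loaded, missing = load_matching_weights(new, old.state_dict())
print("loaded:", loaded)    # ['0.bias', '0.weight']
print("missing:", missing)  # the new output layer: '2.weight' and '2.bias'

# Freeze every parameter restored from the checkpoint; train only the missing (new) ones.
for name, param in new.named_parameters():
    param.requires_grad = name in missing

# Give the optimizer only the trainable parameters.
optimizer = torch.optim.Adam((p for p in new.parameters() if p.requires_grad), lr=1e-4)
```

Because the set of trainable parameters changes, the optimizer state saved in the earlier run generally no longer matches the new optimizer and is best not reloaded.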