How to load part of a pre-trained model?

I’m afraid not

The keys of `state_dict` must exactly match the keys returned by this module's `state_dict()` function.
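
Because the keys have to match exactly, it can help to check which keys the two models share before loading anything. Below is a minimal sketch, assuming two hypothetical models (`pretrained` and `new_model`) that reuse a layer named `features` but have different final layers:

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical "pretrained" network with a 10-way classifier.
pretrained = nn.Sequential(OrderedDict([
    ("features", nn.Linear(8, 4)),
    ("classifier", nn.Linear(4, 10)),
]))
# Hypothetical new network: same "features" layer, different head.
new_model = nn.Sequential(OrderedDict([
    ("features", nn.Linear(8, 4)),
    ("head", nn.Linear(4, 3)),
]))

pretrained_keys = set(pretrained.state_dict().keys())
new_keys = set(new_model.state_dict().keys())

print(sorted(pretrained_keys - new_keys))  # ['classifier.bias', 'classifier.weight']
print(sorted(new_keys - pretrained_keys))  # ['head.bias', 'head.weight']
print(sorted(pretrained_keys & new_keys))  # ['features.bias', 'features.weight']
```

Only the shared keys (`features.*` here) are candidates for copying; the others are what make a plain load fail.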

If a name is not in the new model, it will raise a `KeyError`, but I guess this may work for you:

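    from torch.nn import Parameter

    def load_my_state_dict(self, state_dict):
        """Copy every entry of `state_dict` whose key also exists in this model."""
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                # Key is absent from this model (e.g. an old classifier layer): skip it.
                continue
            if isinstance(param, Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            # state_dict() tensors share storage with the model's parameters,
            # so this in-place copy updates the model itself.
            own_state[name].copy_(param)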
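
As a usage sketch, the method can be attached to a model class and fed the `state_dict()` of a pretrained model. The classes `MyModel` and `Pretrained` and their layer names are hypothetical, chosen so that only `backbone` is shared:

```python
import torch
import torch.nn as nn
from torch.nn import Parameter


class MyModel(nn.Module):
    """Hypothetical target model: same backbone as the pretrained one, new head."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(8, 4)        # same name and shape as in Pretrained
        self.head = nn.Linear(4, num_classes)  # new layer, absent from Pretrained

    def load_my_state_dict(self, state_dict):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue  # skip weights this model does not have
            if isinstance(param, Parameter):
                param = param.data
            own_state[name].copy_(param)


class Pretrained(nn.Module):
    """Hypothetical pretrained model with a layer the target model lacks."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 10)


pretrained = Pretrained()
model = MyModel(num_classes=3)
model.load_my_state_dict(pretrained.state_dict())

# The shared backbone weights were copied; the pretrained classifier was ignored.
print(torch.equal(model.backbone.weight, pretrained.backbone.weight))  # True
```

Note that this only filters by name: if a shared key has a different shape in the two models, `copy_` will raise (unless the shapes happen to be broadcastable), so such keys would need to be skipped as well.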