Unable to load pretrained weights into custom model in PyTorch Lightning

Hello folks,

I want to retrain a custom model on my own data. I can load the pretrained weights (a .pth file) into the model in plain PyTorch and it runs, but I wanted more functionality and refactored the code into PyTorch Lightning. Now I am having trouble loading the pretrained weights into the PyTorch Lightning model. The Lightning code itself works, but I have too little data to train from scratch.

PyTorch model:

class BDRAR(nn.Module):
    def __init__(self, stat=1):
        super(BDRAR, self).__init__()
        resnext = ResNeXt101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

My Lightning model looks like this:

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

My run:

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path,  strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)

I get the following error:

    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
    KeyError: 'state_dict'
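(Note for anyone hitting this: `load_from_checkpoint` expects a Lightning-format checkpoint, i.e. a dict containing a `state_dict` key, whereas a plain .pth saved with `torch.save(model.state_dict())` is the raw state dict itself, so the key lookup fails even with `strict=False`. A self-contained sketch of the difference, using a toy `nn.Linear` in place of BDRAR:)

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # toy stand-in for BDRAR

# Plain PyTorch checkpoint: the saved file *is* the state_dict
torch.save(net.state_dict(), "plain.pth")
plain = torch.load("plain.pth")
print("state_dict" in plain)   # False -> keys are 'weight', 'bias'

# Lightning checkpoint format: the state_dict is nested under a key
torch.save({"state_dict": net.state_dict()}, "lightning_style.pth")
ckpt = torch.load("lightning_style.pth")
print("state_dict" in ckpt)    # True -> what load_from_checkpoint looks for
```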

I will appreciate any help. Thank you so much.


Would it work if you load the state_dict into the internal model directly via liteBDRARobject.model.load_state_dict(), without using the Lightning method?
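(Concretely, since a LightningModule is itself an nn.Module, that suggestion amounts to roughly the sketch below; toy nn.Module stand-ins are used here instead of the actual BDRAR/liteBDRAR classes, and the filename is made up:)

```python
import torch
import torch.nn as nn

class Inner(nn.Module):          # stand-in for BDRAR
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(4, 4)

class LiteWrapper(nn.Module):    # stand-in for liteBDRAR; a LightningModule
    def __init__(self):          # subclasses nn.Module, so this behaves the same
        super().__init__()
        self.model = Inner()

# pretrained weights saved the plain-PyTorch way
torch.save(Inner().state_dict(), "3000_toy.pth")

# load them straight into the wrapped model, bypassing load_from_checkpoint
lite = LiteWrapper()
missing, unexpected = lite.model.load_state_dict(torch.load("3000_toy.pth"), strict=False)
print(missing, unexpected)  # [] [] -> every key matched
```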

Hello @ptrblck, it works when I load it into the plain PyTorch model without using Lightning, as shown below:

def main():
    net = BDRAR().cuda().train()
    print("Network Created")

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])
    print("Training Started.....")
    if len(args['snapshot']) > 0:
        print("training resumes from '%s'" % args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')),strict=False)
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(log_path, 'w') as f:
        f.write(str(args) + '\n\n')
    print("Net is: ")

I’m not familiar enough with Lightning, but are you able to access the “plain” PyTorch model from the Lightning wrapper and then load the state_dict into it directly?

Hi @ptrblck, many thanks for your reply. I appreciate it! I could access the plain PyTorch model from the Lightning wrapper, but I got an error when I tried to load the weights into it.

I got help from another forum and I can load the pretrained model in Lightning now.

Hi @JohnDuke259,

can you please share your solution for loading a PyTorch model into a Lightning module? I’m looking for the same thing.


Hi @erict,

Sorry I just saw this notification.

In your Lightning class:

class liteCalc(pl.LightningModule):

    def __init__(self):
        super(liteCalc, self).__init__()
        self.model = Calc()

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

Outside the Lightning class, before you create your Trainer:

path = './ckpt/x.pth'
model = liteCalc()
model.load_model(path)

I hope this helps.

Hi @JohnDuke259,

yes this is working.

Hi @erict , you are welcome! :grinning: