Unable to load pretrained weights into custom model in PyTorch Lightning

Hello folks,

I want to retrain a custom model on my own data. I can load the pretrained weights (a .pth file) into the model in plain PyTorch and it runs, but I wanted more functionality and refactored the code into PyTorch Lightning. Now I am having trouble loading the pretrained weights into the PyTorch Lightning model. The Lightning code itself works, but I have too little data to train from scratch.

PyTorch model:

class BDRAR(nn.Module):
    def __init__(self, stat=1):
        super(BDRAR, self).__init__()
        resnext = ResNeXt101()
        self.layer0 = resnext.layer0
        self.layer1 = resnext.layer1
        self.layer2 = resnext.layer2
        self.layer3 = resnext.layer3
        self.layer4 = resnext.layer4

My Lightning model looks like this:

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

My run:

path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR.load_from_checkpoint(path,  strict=False)
trainer = pl.Trainer(fast_dev_run=True, gpus=1)

I get the following error:

    keys = model.load_state_dict(checkpoint["state_dict"], strict=strict)
    KeyError: 'state_dict'
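(Note for anyone hitting this: `load_from_checkpoint` expects a Lightning-format checkpoint, i.e. a dict containing a `state_dict` key, whereas a plain .pth saved with `torch.save(model.state_dict())` is the raw state dict itself, so the key lookup fails even with `strict=False`. A self-contained sketch of the difference, using a toy `nn.Linear` in place of BDRAR:)

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # toy stand-in for BDRAR

# Plain PyTorch checkpoint: the saved file *is* the state_dict
torch.save(net.state_dict(), "plain.pth")
plain = torch.load("plain.pth")
print("state_dict" in plain)   # False -> keys are 'weight', 'bias'

# Lightning checkpoint format: the state_dict is nested under a key
torch.save({"state_dict": net.state_dict()}, "lightning_style.pth")
ckpt = torch.load("lightning_style.pth")
print("state_dict" in ckpt)    # True -> what load_from_checkpoint looks for
```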

I will appreciate any help. Thank you so much.


Would it work if you load the state_dict into the internal model directly via liteBDRARobject.model.load_state_dict(), without using the Lightning method?
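(Concretely, since a LightningModule is itself an nn.Module, that suggestion amounts to roughly the sketch below; toy nn.Module stand-ins are used here instead of the actual BDRAR/liteBDRAR classes, and the filename is made up:)

```python
import torch
import torch.nn as nn

class Inner(nn.Module):          # stand-in for BDRAR
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(4, 4)

class LiteWrapper(nn.Module):    # stand-in for liteBDRAR; a LightningModule
    def __init__(self):          # subclasses nn.Module, so this behaves the same
        super().__init__()
        self.model = Inner()

# pretrained weights saved the plain-PyTorch way
torch.save(Inner().state_dict(), "3000_toy.pth")

# load them straight into the wrapped model, bypassing load_from_checkpoint
lite = LiteWrapper()
missing, unexpected = lite.model.load_state_dict(torch.load("3000_toy.pth"), strict=False)
print(missing, unexpected)  # [] [] -> every key matched
```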

Hello @ptrblck, it works when I load it into the plain PyTorch model without using Lightning, as shown below:

def main():
    net = BDRAR().cuda().train()
    print("Network Created")

    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'])
    print("Training Started.....")
    if len(args['snapshot']) > 0:
        print("training resumes from '%s'" % args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')),strict=False)
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '_optim.pth')))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(log_path, 'w') as f:
        f.write(str(args) + '\n\n')
    print("Net is: ")

I’m not familiar enough with Lightning, but are you able to access the “plain” PyTorch model from the Lightning wrapper and then load the state_dict into it directly?

Hi @ptrblck, many thanks for your reply. I appreciate it! I could access the plain PyTorch model from the Lightning wrapper, but I got an error when I tried to load the weights into it.

I got help from another forum and I can load the pretrained model in Lightning now.

Hi @JohnDuke259,

can you please share your solution for loading a PyTorch model into a Lightning module? I’m looking for the same thing.


Hi @erict,

Sorry I just saw this notification.

In your Lightning class:

class liteCalc(pl.LightningModule):

    def __init__(self):
        super(liteCalc, self).__init__()
        self.model = Calc()

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)
        print('Model Created!')

    def forward(self, x):
        return self.model(x)

Outside the Lightning class, before you create your Trainer:

path = './ckpt/x.pth'
model = liteCalc()
model.load_model(path)

I hope this helps.

Hi @JohnDuke259,

yes this is working.

Hi @erict , you are welcome! :grinning: