I am not sure I can load the checkpoint like that, since Lightning takes care of that for me (unless there is a way to override it?).
Here is some context, reduced/simplified as much as possible. I am using the LightningCLI, and all my configuration is driven from a config.yaml file.
main.py
cli = LightningCLI(..., run=False)
preds = cli.trainer.predict(cli.model, cli.datamodule, return_predictions=True, ckpt_path='<my_ckpt.pth.tar>')
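For the record, one alternative I could imagine (not sure whether it is the intended way with the CLI) is to load the raw weights into the wrapped network myself and call predict without ckpt_path, roughly:

import torch

# Hypothetical workaround (untested): load the raw, un-prefixed weights straight
# into the wrapped smp network, then predict without ckpt_path so Lightning does
# not try to restore a checkpoint it did not write. Assumes the file stores the
# weights under a 'state_dict' entry.
raw = torch.load('<my_ckpt.pth.tar>', map_location='cpu')
cli.model.model.load_state_dict(raw['state_dict'])
preds = cli.trainer.predict(cli.model, cli.datamodule, return_predictions=True)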
Model
from typing import Any

import lightning.pytorch as pl
import segmentation_models_pytorch as smp
import torch.nn as nn


class DeepLabV3Plus(pl.LightningModule):
    def __init__(self, num_classes, encoder_name, encoder_depth,
                 encoder_weights, activation, loss_fn):
        super().__init__()
        # The smp network is stored as `self.model`, so its weights end up
        # under a `model.` prefix in this LightningModule's state_dict.
        self.model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                       encoder_depth=encoder_depth,
                                       encoder_weights=encoder_weights,
                                       classes=num_classes,
                                       activation=activation)
        self.loss_fn = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def configure_optimizers(self) -> Any:
        return super().configure_optimizers()

    def forward(self, x):
        preds = self.model(x)
        return preds

    def training_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
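For context: because the smp network is stored as self.model, every key in the LightningModule's state_dict carries a model. prefix, while the keys in my original checkpoint presumably do not, which is why they need re-prefixing. A quick way to see the expected keys (the constructor arguments below are placeholders for my real config values):

# Placeholder arguments; only the key names matter here.
lit_model = DeepLabV3Plus(num_classes=2, encoder_name="resnet34", encoder_depth=5,
                          encoder_weights=None, activation=None, loss_fn=None)
print(list(lit_model.state_dict().keys())[:3])
# e.g. ['model.encoder.conv1.weight', 'model.encoder.bn1.weight', 'model.encoder.bn1.bias']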
In a separate script I attempted to fix the checkpoint file like this:
from collections import OrderedDict
import torch
# Load the checkpoint
checkpoint_path = "checkpoints/my_checkpoint.pth.tar"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# create new checkpoint
modified_checkpoint = OrderedDict()
modified_checkpoint['state_dict'] = OrderedDict()
# prefix keys in state_dict
state_dict = checkpoint['state_dict'].copy()
for k, v in state_dict.items():
    k = f'model.{k}'
    modified_checkpoint['state_dict'][k] = v
# add missing keys
modified_checkpoint['pytorch-lightning_version'] = '0.0.0'
modified_checkpoint['global_step'] = None
modified_checkpoint['epoch'] = None
modified_checkpoint['state_dict']['predict_loop'] = None
# save
modified_checkpoint_path = "my_checkpoint_edited.pth.tar"
torch.save(modified_checkpoint, modified_checkpoint_path)
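As a sanity check on the edited file (before handing it back to trainer.predict via ckpt_path), I could compare its keys against the LightningModule, again with placeholder constructor arguments standing in for my real config:

import torch

# Load the edited checkpoint and try it against the LightningModule's keys.
# strict=False so the extra 'predict_loop' entry does not raise.
edited = torch.load("my_checkpoint_edited.pth.tar", map_location="cpu")
lit_model = DeepLabV3Plus(num_classes=2, encoder_name="resnet34", encoder_depth=5,
                          encoder_weights=None, activation=None, loss_fn=None)
missing, unexpected = lit_model.load_state_dict(edited["state_dict"], strict=False)
print("missing keys:   ", missing)
print("unexpected keys:", unexpected)  # expect the stray 'predict_loop' key here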