Issues using a non-Lightning checkpoint in Lightning

I have a checkpoint that was trained with a plain PyTorch implementation. The model used was DeepLabV3Plus from the segmentation_models_pytorch library.

I am trying to load the checkpoint with PyTorch Lightning but I am running into a few issues.

First I was getting KeyErrors for pytorch-lightning_version, global_step and epoch, so I set these to dummy values.

Then I was getting the “Unexpected / Missing keys in state_dict” error, which I solved by prefixing the keys in the state_dict with model. (and removing the old, unprefixed ones).

But now I am stuck. I get the error:

self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"])
KeyError: 'predict_loop'

Ok, so I added the predict_loop key to the state_dict and set it to None.

Now I get:

RuntimeError: Error(s) in loading state_dict for DeepLabV3Plus:
        Unexpected key(s) in state_dict: "predict_loop". 

I assume the value of predict_loop can’t be None, but in that case the “Unexpected key(s)” error seems misleading.

Any ideas on how to solve this?

I’m not exactly sure what you are doing from your description, but maybe loading the parameters into the model directly would be a good approach, i.e. foo.model.load_state_dict(...) if foo.model is the original model.
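
Something like this, as a rough sketch; the encoder settings below are placeholders (use whatever you trained with), and some checkpoints nest the weights under a “state_dict” key while others are a bare dict of tensors, so adjust that lookup to your file:

import segmentation_models_pytorch as smp
import torch

# build the bare smp model (placeholder settings, not your actual config)
model = smp.DeepLabV3Plus(encoder_name="resnet34", encoder_weights=None, classes=2)

# load the plain-PyTorch checkpoint and pull out the weights
ckpt = torch.load("my_ckpt.pth.tar", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # handle both nested and bare layouts

# load the weights directly, bypassing Lightning's checkpoint restore
model.load_state_dict(state_dict)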

Best regards

Thomas

I am not sure I can load the checkpoint like that since Lightning takes care of that for me (unless there’s a way to override it?)

Here’s some context (reduced/simplified as much as possible). I am using the CLI and all my config is driven from a config.yaml file.

main.py

cli = LightningCLI(..., run=False)
preds = cli.trainer.predict(cli.model, cli.datamodule, return_predictions=True, ckpt_path='<my_ckpt.pth.tar>')

Model

from typing import Any

import lightning.pytorch as pl
import segmentation_models_pytorch as smp
import torch.nn as nn

class DeepLabV3Plus(pl.LightningModule):
    def __init__(self, num_classes, encoder_name, encoder_depth,
                 encoder_weights, activation, loss_fn):
        super().__init__()

        self.model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                       encoder_depth=encoder_depth,
                                       encoder_weights=encoder_weights,
                                       classes=num_classes,
                                       activation=activation)

        self.loss_fn = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def configure_optimizers(self) -> Any:
        return super().configure_optimizers()

    def forward(self, x):
        preds = self.model(x)
        return preds

    def training_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        img, gt = batch
        output = self.model(img)
        loss = self.loss_fn(output, gt)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)

In a separate script I attempted to fix the checkpoint file like this:

from collections import OrderedDict
import torch

# Load the checkpoint
checkpoint_path = "checkpoints/my_checkpoint.pth.tar"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# create new checkpoint
modified_checkpoint = OrderedDict()
modified_checkpoint['state_dict'] = OrderedDict()

# prefix keys in state_dict
state_dict = checkpoint['state_dict'].copy()
for k, v in state_dict.items():
    k = f'model.{k}'
    modified_checkpoint['state_dict'][k] = v

# add missing keys
modified_checkpoint['pytorch-lightning_version'] = '0.0.0'
modified_checkpoint['global_step'] = None
modified_checkpoint['epoch'] = None
modified_checkpoint['state_dict']['predict_loop'] = None

# save
modified_checkpoint_path = "my_checkpoint_edited.pth.tar"
torch.save(modified_checkpoint, modified_checkpoint_path)

How about loading the checkpoint into the (inner) model and then saving the full Lightning model to get a Lightning checkpoint?
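
Roughly like this, assuming the DeepLabV3Plus LightningModule you posted above; the constructor values are placeholders for whatever you trained with, and I haven’t run this against your exact versions:

import lightning.pytorch as pl
import torch

# instantiate the LightningModule from the earlier post
# (placeholder hyperparameters, substitute your real training config)
lit_model = DeepLabV3Plus(num_classes=2, encoder_name="resnet34",
                          encoder_depth=5, encoder_weights=None,
                          activation=None, loss_fn="cross_entropy")

# load the plain-PyTorch weights into the inner smp model
ckpt = torch.load("checkpoints/my_checkpoint.pth.tar", map_location="cpu")
lit_model.model.load_state_dict(ckpt["state_dict"])

# attach the model to a fresh Trainer and let Lightning write a proper .ckpt
trainer = pl.Trainer()
trainer.strategy.connect(lit_model)
trainer.save_checkpoint("my_checkpoint_lightning.ckpt")

A checkpoint written this way should contain the epoch, global_step, hyperparameter and loop entries Lightning expects, so it can be passed as ckpt_path instead of the hand-edited file.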

Did you ever figure this out? I’ve got a really similar problem, except that the model I trained didn’t use Lightning, while the script I’m using to generate output from it expects a checkpoint from a LightningModule. I’m going to brazenly steal your DeepLabV3Plus code above to see if I can get the model running as a LightningModule. Thanks for any comments. I’m just learning Lightning, stuck in the endless learn-what-you-need-to-keep-going cycle.

As far as I remember, I couldn’t get it working for my DeepLabV3Plus model, but I recently did (almost) the same thing for a Faster R-CNN checkpoint and it worked.

This is more or less my code:

from collections import OrderedDict

import torch

# Load the checkpoint
checkpoint_path = "broken.ckpt"
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# create new checkpoint
modified_checkpoint = OrderedDict()
modified_checkpoint['state_dict'] = {}

# prefix keys in state_dict (here the raw checkpoint is itself a bare
# state_dict, i.e. just parameter tensors with no 'state_dict' wrapper)
for k, v in checkpoint.items():
    k = f'model.{k}'
    modified_checkpoint['state_dict'][k] = v

# add missing keys
modified_checkpoint['pytorch-lightning_version'] = '0.0.0'
modified_checkpoint['global_step'] = None
modified_checkpoint['epoch'] = None

# save
modified_checkpoint_path = "fixed.ckpt"
torch.save(modified_checkpoint, modified_checkpoint_path)
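
A quick sanity check you could add after saving, to confirm the renaming worked before handing the file to Lightning (this assumes the original checkpoint is a bare state_dict, as it was in my case):

# reload the repaired file and check that every key got the "model." prefix
fixed = torch.load(modified_checkpoint_path, map_location=torch.device('cpu'))
assert len(fixed['state_dict']) == len(checkpoint)
assert all(k.startswith('model.') for k in fixed['state_dict'])
print(list(fixed['state_dict'])[:5])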