How to load a quantised model in PyTorch or PyTorch Lightning?

I applied quantisation-aware training (QAT) using PyTorch Lightning to one of my architectures for faster inference. The model trained successfully, but I am facing model-loading issues during inference. I've come across a few forum threads with this same issue but couldn't find a satisfactory method to resolve it. Any help would be highly appreciated, thanks!

Below is the attached code (I trained the QAT model using PyTorch Lightning; the issue arises when I try to load it).

Training Code

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.loggers import WandbLogger


class Module(pl.LightningModule):

    def __init__(self, model, lr=1e-5):
        super().__init__()
        self.save_hyperparameters()

        # Use the model passed in rather than re-creating it here
        self.model = model
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        self.log("train_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        # logged so that ModelCheckpoint(monitor='val_loss') has a metric to track
        self.log("val_loss", loss)

        return loss

    def test_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        self.log("test_loss", loss)

        return loss


if __name__ == '__main__':


    scale_factor = 3
    batch_size = 24
    epochs = 1
    lr = 1e-5

    input_image_path = '....'  
    target_image_path = '....'

    val_input_path = '....'
    val_target_path = '...'
    
    prev_ckpt_path = '...'
    device = 'cpu'  # kept on CPU, since quantisation-aware training doesn't support GPU
    
    # Define model
    model = SRModel(scale_factor).to(device)
    
    module = Module(model).load_from_checkpoint(prev_ckpt_path)

    # Setup dataloaders


    train_dataset = CustomDataset(input_image_path, target_image_path)
    
    training_dataloader = DataLoader(
                dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)

    val_dataset = CustomDataset(val_input_path, val_target_path)
    
    val_dataloader = DataLoader(
        dataset=val_dataset, num_workers=4, batch_size=batch_size, shuffle=False)
    

    wandb_logger = WandbLogger()
    checkpoint_callback = ModelCheckpoint(monitor='val_loss')

    trainer = pl.Trainer(max_epochs=epochs, gpus=0, auto_lr_find=True,
                         logger=wandb_logger, progress_bar_refresh_rate=3,
                         callbacks=[QuantizationAwareTraining(observer_type='histogram', input_compatible=True),
                                    checkpoint_callback])


    trainer.fit(
        module,
        training_dataloader,
        val_dataloader
    )
    

    trainer.save_checkpoint("Quantised.pth")
    trainer.save_checkpoint("Quantised.ckpt")

And this is the inference code:

Inference


import torch
import pytorch_lightning as pl


class Module(pl.LightningModule):

    def __init__(self, model):
        super().__init__()
        self.save_hyperparameters()

        # Same wrapper as in training: use the model passed in
        self.model = model

    def forward(self, x):
        return self.model(x)


prev_ckpt_path = '.....ckpt'

device = 'cpu'

# Define model
model = SRModel(3).to(device)

module = Module(model).load_from_checkpoint(prev_ckpt_path, strict=False)

RuntimeError: an exception occurred: ('Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor',).

This is a bit tricky at the moment (see also Can't load_from_checkpoint with QuantizationAwareTraining callback during training · Issue #6457 · PyTorchLightning/pytorch-lightning · GitHub), and the problem is that the model isn't quantized yet.
If your model permits it, the easiest option can be to save the scripted model (torch.jit.save(module.to_torchscript(), 'my_model.pt')), but there, too, there is the slight caveat that you need to call the quantization / dequantization yourself (model.to_torchscript "forgets" input quantization that is automatically called for the model · Issue #7552 · PyTorchLightning/pytorch-lightning · GitHub).

Edit: to call quantization/dequantization: model.dequant(model(model.quant(inp))). This might break when PL gets fixed, but should do the trick until then.
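Putting the two together, a minimal sketch of the TorchScript route (assuming module is the LightningModule trained with the QuantizationAwareTraining callback above, which attaches the quant/dequant stubs; the file name and input shape are placeholders):

import torch

# After training: export the prepared/quantized LightningModule to TorchScript.
torch.jit.save(module.to_torchscript(), 'my_model.pt')

# At inference time: load the scripted model and wrap the call with the
# quant/dequant stubs yourself, since to_torchscript() currently drops the
# automatic input quantization (issue #7552 above).
model = torch.jit.load('my_model.pt')
model.eval()

inp = torch.rand(1, 3, 64, 64)  # placeholder input
with torch.no_grad():
    out = model.dequant(model(model.quant(inp)))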

I hope this helps.

Best regards

Thomas

Thanks @tom! I think I need to test using TorchScript and see whether the issue still exists.

Yeah, that first link makes a lot of sense. This issue is common: the model needs to be prepared for QAT, but the prepare step happens at some point within the PL framework (not at the start), so loading a checkpoint into the model before it has been prepared is what causes the error.

This is primarily a PL issue, but you should be able to circumvent it by creating a new PL module that does the quantization prepare steps right after model load rather than at the checkpoint hook. You could, for example, call it manually before loading (though there may be issues with preparing a model twice if the quantization prepare step in PL doesn't check for that).

So are you suggesting something like self.model = quant(model) (during inference) and then loading the model weights?

No, you’d need to do the pre-inference model preparation, i.e. the fuse and prepare steps.

This is the code that PL runs to prepare the model: pytorch-lightning/quantization.py at 92cf396de2fe49e89a625a200d641bd8b6aeb328 · PyTorchLightning/pytorch-lightning · GitHub

This is what needs to be run in order to load the checkpoint, since the checkpoint is for the model after it's been fused/prepared.

Figuring out how to do this would require PL expertise that I don't have; it may be a good idea to ask at the PL forum: https://forums.pytorchlightning.ai/
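For reference, a rough eager-mode sketch of those fuse/prepare steps before loading the state dict (untested; the fusion list, qconfig, and key remapping are assumptions that have to match what the QuantizationAwareTraining callback actually did, and the final convert step only applies if the checkpoint already holds quantized tensors, which the RuntimeError above suggests):

import torch
import torch.quantization as tq

model = SRModel(3)

# 1. Fuse the same conv/relu blocks the PL callback fused (module names are placeholders).
# model = tq.fuse_modules(model.eval(), [['conv1', 'relu1']])

# 2. Attach a qconfig and insert the observer / fake-quant modules.
#    This must match the callback's configuration (observer_type='histogram' above);
#    the default QAT qconfig is used here only as a stand-in.
model.qconfig = tq.get_default_qat_qconfig('fbgemm')
tq.prepare_qat(model.train(), inplace=True)

# 3. Convert to an actual quantized model so its tensors match the quantized
#    tensors stored in the checkpoint.
tq.convert(model.eval(), inplace=True)

# 4. With the structure now matching the checkpoint, the weights should load.
ckpt = torch.load(prev_ckpt_path, map_location='cpu')
state_dict = {k[len('model.'):]: v for k, v in ckpt['state_dict'].items()
              if k.startswith('model.')}
model.load_state_dict(state_dict, strict=False)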


Okay, no worries! Thanks a lot!