I applied quantisation-aware training (QAT) with PyTorch Lightning to one of my architectures for faster inference. The model trains successfully, but I am running into model-loading issues at inference time. I have come across a few forum threads describing the same problem, but none of the suggested fixes resolved it for me. Any help would be highly appreciated, thanks!
Below is the attached code (I trained the QAT model with PyTorch Lightning; the issue arises when I try to load the checkpoint).
Training Code
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.loggers import WandbLogger

class Module(pl.LightningModule):
    def __init__(self, model, lr=1e-5):
        super().__init__()
        # Ignore the model when saving hyperparameters, so the whole
        # nn.Module is not pickled into the checkpoint's hparams.
        self.save_hyperparameters(ignore=['model'])
        self.model = model  # use the model passed in
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        return loss

    def validation_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        self.log('val_loss', loss)  # logged so ModelCheckpoint can monitor it
        return loss

    def test_step(self, batch, batch_idx):
        lr, hr = batch
        sr = self(lr)
        loss = F.mse_loss(sr, hr, reduction="mean")
        return loss
if __name__ == '__main__':
    scale_factor = 3
    batch_size = 24
    epochs = 1
    lr = 1e-5
    input_image_path = '....'
    target_image_path = '....'
    val_input_path = '....'
    val_target_path = '...'
    prev_ckpt_path = '...'
    device = 'cpu'  # kept on CPU, since quantisation-aware training doesn't support GPU

    # Define the model and restore the float weights from the previous checkpoint.
    # Note: load_from_checkpoint is a classmethod, so it is called on the class,
    # not on an already-constructed instance.
    model = SRModel(scale_factor).to(device)
    module = Module.load_from_checkpoint(prev_ckpt_path, model=model, lr=lr)

    # Set up the dataloaders
    train_dataset = CustomDataset(input_image_path, target_image_path)
    training_dataloader = DataLoader(
        dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    val_dataset = CustomDataset(val_input_path, val_target_path)
    val_dataloader = DataLoader(
        dataset=val_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

    wandb_logger = WandbLogger()  # logger referenced by the trainer below
    checkpoint_callback = ModelCheckpoint(monitor='val_loss')
    trainer = pl.Trainer(max_epochs=epochs, gpus=0, auto_lr_find=True,
                         logger=wandb_logger, progress_bar_refresh_rate=3,
                         callbacks=[QuantizationAwareTraining(observer_type='histogram',
                                                              input_compatible=True),
                                    checkpoint_callback])
    trainer.fit(module, training_dataloader, val_dataloader)

    trainer.save_checkpoint("Quantised.pth")
    trainer.save_checkpoint("Quantised.ckpt")
And this is the inference code
Inference
import torch
import pytorch_lightning as pl

class Module(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.save_hyperparameters(ignore=['model'])
        self.model = model  # use the model passed in

    def forward(self, x):
        return self.model(x)

prev_ckpt_path = '.....ckpt'
device = 'cpu'

# Define the model and try to load the QAT checkpoint
model = SRModel(3).to(device)
module = Module.load_from_checkpoint(prev_ckpt_path, model=model, strict=False)
This fails with:

RuntimeError: an exception occurred: ('Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor',)
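From that error I suspect the checkpoint holds the converted (quantized) weights, while the freshly built SRModel is still a float model, so load_state_dict cannot copy quantized tensors into float parameters. Below is a sketch of what I have been attempting on the loading side. It assumes the Lightning callback uses torch.quantization's prepare_qat/convert pipeline with a default qconfig, which may not match what the callback actually does internally, hence strict=False. Is this the right direction, or is there a supported way to restore a QAT checkpoint for inference?

import torch
import torch.quantization as tq

# Sketch only: rebuild the float model, run it through the same
# prepare_qat -> convert pipeline I assume the Lightning callback used,
# then load the quantized state dict into the converted module.
float_model = SRModel(3)
float_model.train()  # prepare_qat expects a model in training mode
float_model.qconfig = tq.get_default_qat_qconfig("fbgemm")  # assumption: backend/qconfig
tq.prepare_qat(float_model, inplace=True)
quantized_model = tq.convert(float_model.eval(), inplace=False)

ckpt = torch.load(prev_ckpt_path, map_location="cpu")
# Lightning prefixes parameter names with "model."; strip it to match the bare module.
state_dict = {k.replace("model.", "", 1): v for k, v in ckpt["state_dict"].items()}
quantized_model.load_state_dict(state_dict, strict=False)
quantized_model.eval()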