Quantization performance is very different between fake quantized and true quantized

Dear all,

I’m trying to quantize a ResNet50. I have decided to use QAT to train a fake quantize version of the model. The training statistics are good, but after torch.quantization.convert(model, inplace=True) they are completely different. Indeed if I evaluate the model in the fake quantized version I obtain a F1 score of 0.99 while after convert it reaches only 0.70. I don’t think it is normal, there is something wrong in the code (I attach it below).

import torch 
import torchvision
from tqdm import tqdm
import sklearn
import numpy as np 
from torchvision import transforms
import random
from torchvision.datasets import ImageFolder
from torchvision.models.quantization import resnet50
import torchmetrics
import torch.nn.functional as F
from collections import Counter
#from quantized_models.resnet50_sicura import ResNet50_sicura
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from observers.PACTFakeQuantised import PACTFakeQuantize
from split_data import generate_data_structure
from per_patient_sicura.Qresnet_pp import QResnet50pp
from per_patient_sicura.sicura_dataset_pp import SicuraPPDataset
import os
from torch.ao.quantization import FakeQuantize
from torch.ao.quantization.observer import MovingAveragePerChannelMinMaxObserver


bonus_dataset = '/home/mdatres/data/si-cura/data_bonus/dataset_bonus/CU_MC_new_datasets'
path_to_ourdata = '/home/mdatres/data/dataset_per_patient'

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
    
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1, p = 1):
        self.std = std
        self.mean = mean
        self.p = p
    def __call__(self, tensor):
        if random.uniform(0, 1) < self.p:
            return tensor + torch.randn(tensor.size()) * self.std + self.mean
        else: 
            return tensor
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class Subset(Dataset):
    r"""
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices, transform):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __getitem__(self, idx):
        
        im, labels = self.dataset[self.indices[idx]]
        
        return self.transform(im), labels

    def __len__(self):
        return len(self.indices)

def compute_class_weight(dataset: Subset, indices):
    """
    Compute the weight for each class that need to be passed to the sampler in the 
    torch.utils.data.WeightedRandomSampler when building the train_dataloader
    """
    dataset = dataset.dataset
    count_element_classes = Counter(dataset.targets[i] for i in indices) #store a the num of elem for each class {'Class1':num1, 'Class2':num2}
    l = np.array(list(count_element_classes.values())) #from a dictionary to a np,array [numClass1,numClass2]
    weight = l/(sum(count_element_classes.values())) #get the weigth for each class [weighClass1,weighClass2]
    samples_weight = np.array([weight[sample] for j,sample in enumerate(dataset.targets) if j in indices]) #get the vector of sample weights
    samples_weight=torch.from_numpy(samples_weight) #transform into tensor

    return samples_weight, torch.from_numpy(weight).float()
def compute_class_weight(dataset: Subset, indices):
    """
    Compute the weight for each class that need to be passed to the sampler in the 
    torch.utils.data.WeightedRandomSampler when building the train_dataloader
    """
    dataset = dataset.dataset
    count_element_classes = Counter(dataset.targets[i] for i in indices) #store a the num of elem for each class {'Class1':num1, 'Class2':num2}
    l = np.array(list(count_element_classes.values()))# from a dictionary to a np,array [numClass1,numClass2]
    weight = [sum(l)/r for r in l]
    weight = weight/(sum(weight))
      #get the weigth for each class [weighClass1,weighClass2]
    samples_weight = np.array([weight[sample] for j,sample in enumerate(dataset.targets) ]) #get the vector of sample weight
    samples_weight=torch.from_numpy(samples_weight)# transform into tensor

    return samples_weight, torch.from_numpy(weight).float()

all_val_loss = []
save_test = []


my_qconfig =  torch.quantization.get_default_qconfig("fbgemm")
    
#train transforms
train_transforms = transforms.Compose([
        transforms.Resize((256,256)),
        #random_colour_transform,
        transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4)],
                    p=0.6
                ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        #AddGaussianNoise(std=0.01, p=0.5),
    #normalize,
    ])

#test transforms
test_transforms = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    #normalize,
])
batch_size = 64

train_val_data = ImageFolder(
        root=bonus_dataset)
label_train_val_data = [sample[1] for sample in train_val_data.samples]
epochs = 50
if torch.cuda.is_available():
        print('using cuda')
        device = torch.device("cuda")
else:
    print('using cpu')
    device = torch.device("cpu")

import wandb

wandb.init(project="sicura_quantized")
#5-fold stratified
kfold_stratified = StratifiedKFold(
    n_splits=10, shuffle=True)
for fold, (train_ids, val_ids) in tqdm(enumerate(kfold_stratified.split(train_val_data, label_train_val_data))):
    print("------"*20)
    print('Fold {} starting...'.format(fold + 1))
    min_val_loss = float('inf')
    #train-val split for each fold of the dataset
    train_dataset_cv = Subset(train_val_data, train_ids, transform= train_transforms)
    
    sample_weight_per_class_train, train_weight = compute_class_weight(train_dataset_cv, train_ids)
    
    #define a sampler for class imbalace in the training
    sampler = torch.utils.data.WeightedRandomSampler(sample_weight_per_class_train, len(sample_weight_per_class_train))
    train_data_loader = DataLoader(train_dataset_cv, batch_size=batch_size,shuffle=True, num_workers=0, pin_memory=True)
    val_dataset_cv = Subset(train_val_data, val_ids, transform= test_transforms)
    weight_per_class_val , val_weight= compute_class_weight(val_dataset_cv, val_ids)
    val_data_loader = DataLoader(
        val_dataset_cv, batch_size=batch_size, num_workers=8, shuffle=True, pin_memory=True)
    model = resnet50(weights="ResNet50_Weights.IMAGENET1K_V2")
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    model.load_state_dict(torch.load("/home/mdatres/sicura_SSL/best_fold_0.pth"))
    model.qconfig = my_qconfig
    print("Quantization according to the observers: " + str(model.qconfig))
    torch.quantization.prepare_qat(model, inplace=True)
    f1 = torchmetrics.F1Score(task='binary', num_classes=2).to(device)
    mcc = torchmetrics.MatthewsCorrCoef(task = 'binary').to(device)
    criterion = torch.nn.CrossEntropyLoss(weight=train_weight)

    # #freeze params
    # for params in model.parameters():
    #     params.requires_grad = False# disable autograd tracking
    
    # for params in model.layer3.parameters():
    #     params.requires_grad = True
    
    # #unfreeze params onlny in 3rd and 4th layers to retrain them
    # for params in model.layer4.parameters():
    #     params.requires_grad = True

    
    optimizer =  torch.optim.Adam(
           model.parameters(), lr = 1e-4)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
    if torch.cuda.is_available():
        print('using cuda')
        device = torch.device("cuda")
    else:
        print('using cpu')
        device = torch.device("cpu")
    model.to(device)
    criterion.to(device)
    early_stopper = EarlyStopper(patience=1, min_delta=0.0)
    for epoch in tqdm(range(epochs + 1)):
        model.train()
        train_loss = 0.0
        for data, target in train_data_loader:
            
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            if epoch>=5:
                model.apply(torch.quantization.disable_observer)
            if epoch > 7:
               # Freeze batch norm mean and variance estimates
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            output = model(data)
            l1_norm_all = torch.tensor([0.0]).to(device)
            for p in model.parameters():
                if p.requires_grad:
                    l1_norm_all += p.abs().sum()
            loss = criterion(output, target) #+ 1e-5*l1_norm_all
            loss.backward()
            optimizer.step()
            train_loss += loss.data.item() * data.size(0)
            f1.update(torch.argmax(output.softmax(dim=-1),1), target)
            mcc.update(torch.argmax(output.softmax(dim=-1),1), target)
        
        wandb.log({'lr_' + str(fold):  optimizer.state_dict()['param_groups'][0]['lr']})
        f1_save = f1.compute()
        mcc_save = mcc.compute()
        train_loss /= len(train_data_loader.dataset)
        
        print('Epoch_Train:{}, T_Loss:{:.2f}, F1_score= {:.2f}, MCC= {:.4f}'
                .format(epoch, train_loss, f1_save, mcc_save ))
        
        wandb.log({'f1_train'+ '_fold_'  + str(fold): f1_save.item(), 'mcc_train'  + '_fold_' + str(fold): mcc_save.item(), 'loss_train'+'_fold_' + str(fold): train_loss})
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_data_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
               # get the index of the max log-probability
                loss = F.cross_entropy(output, target, weight=val_weight.to(device))
                l1_norm_all = torch.tensor(0.0).to(device)
                for p in model.parameters():
                    if p.requires_grad:
                        l1_norm_all += p.abs().sum()
                #loss += 1e-5*l1_norm_all
                f1.update(torch.argmax(output.softmax(dim=-1),1), target)
                mcc.update(torch.argmax(output.softmax(dim=-1),1), target)
                val_loss += loss.data.item() * data.size(0)
        
        val_loss = val_loss/len(val_data_loader.dataset)
        lr_scheduler.step(metrics=val_loss)
        f1val = f1.compute()
        mccval = mcc.compute()
        print('Epoch_Val:{}, Val_Loss:{:.2f}, F1= {:.2f}, MCC= {:.4f}'
                .format(epoch, val_loss, f1val, mccval ))
        
        wandb.log({'acc_val' + '_fold_' +  str(fold): f1val.item(), 'mcc_val' + '_fold_'+ str(fold): mccval.item(), 'loss_val' + '_fold_' + str(fold): val_loss})
        if early_stopper.early_stop(val_loss):             
            break
        if val_loss < min_val_loss: 
            print("The best is at epoch " + str(epoch))
            torch.save(model.state_dict(), '/home/mdatres/sicura_SSL/result_pact/bestq' + '_fold_'+str(fold) + '.pth')
            min_val_loss = val_loss
        all_val_loss.append((min_val_loss, '/home/mdatres/sicura_SSL/result_pact/bestq'  + '_fold_'+str(fold) + '.pth'))

#Testing phase 

model = resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)

model.qconfig = my_qconfig
print("Quantization according to the observers: " + str(model.qconfig))
torch.quantization.prepare_qat(model, inplace=True)
data = SicuraPPDataset(path_to_ourdata, test_transforms)
def second_item(data):
        return data[0]
all_val_loss.sort(key=second_item, reverse=True)
with open('/home/mdatres/sicura_SSL/result_pact/best_all.txt', 'w') as f:
    f.write('The best choosen model is ' + all_val_loss[0][1])
model.load_state_dict(torch.load(all_val_loss[0][1]))
model.to("cpu")
model.eval()
torch.quantization.convert(model, inplace=True)
netpp = QResnet50pp(threshold=0.7,model=model)

netpp.to("cpu")
netpp.eval()

f1_test = torchmetrics.F1Score(task='binary').to("cpu")
mcc_test = torchmetrics.MatthewsCorrCoef(task = 'binary').to("cpu")
for d in data:
    o = netpp(input=d, device = "cpu", k=1)
    f1_test.update(torch.tensor([o]).to("cpu"), torch.tensor([d[1]]).to("cpu"))
    mcc_test.update(torch.tensor([o]).to("cpu"), torch.tensor([d[1]]).to("cpu"))

f1_final = f1_test.compute()
mcc_final = mcc_test.compute()
print('Test Statistic'+':  '+ 'F1_score:  ' +str(f1_final.item())+ '  MCC_score:  ' +str(mcc_final.item()))
save_test.append({"F1_score": str(f1_final.item()), "MCC_score": str(mcc_final.item())})

Thanks in advance,
mdatres

Did you succeed? I had the same problem doing QAT on other models