ResNet50 Quantization on medical data

Dear all,

I am new to quantization, but I am interested in trying it out on medical images. I followed the PyTorch tutorial on MNIST and everything works fine. I then tried it on my own data, but the results in terms of performance are terrible. I have tried all the available quantization pipelines (static/dynamic/quantization-aware), but none of them seems to work on my task. Below I attach the code that I used for static quantization.

import torch
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models.quantization import resnet50
from collections import Counter
import torch.quantization._numeric_suite as ns
from torch.utils.data import DataLoader, Dataset


bonus_dataset = '/home/mdatres/data/si-cura/data_bonus/dataset_bonus/CU_MC_new_datasets'
path_to_ourdata = '/home/mdatres/data/dataset_per_patient'

def compute_class_weight(dataset: Dataset):
    """
    Compute the per-sample weights that need to be passed to
    torch.utils.data.WeightedRandomSampler when building the train dataloader.
    """

    count_element_classes = Counter(dataset.targets)  # number of samples per class, e.g. {0: num0, 1: num1}
    counts = np.array(list(count_element_classes.values()))  # from dictionary to np.array [num0, num1]
    weight = np.array([counts.sum() / c for c in counts])  # inverse-frequency weight per class
    weight = weight / weight.sum()  # normalize so the class weights sum to 1
    class_to_weight = dict(zip(count_element_classes.keys(), weight))  # map class label -> weight
    samples_weight = np.array([class_to_weight[t] for t in dataset.targets])  # weight of each sample
    samples_weight = torch.from_numpy(samples_weight)  # transform into a tensor

    return samples_weight, torch.from_numpy(weight).float()

# test transforms
test_transforms = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    #normalize,
])
batch_size = 1

train_val_data = ImageFolder(
        root=bonus_dataset, transform=test_transforms)
sample_weight_per_class_train, train_weight = compute_class_weight(train_val_data)

sampler = torch.utils.data.WeightedRandomSampler(sample_weight_per_class_train, len(sample_weight_per_class_train))
train_data_loader = DataLoader(train_val_data, batch_size=batch_size,sampler = sampler, num_workers=0, pin_memory=True)
label_train_val_data = [sample[1] for sample in train_val_data.samples]
epochs = 50
if torch.cuda.is_available():
    print('using cuda')
    device = torch.device("cuda")
else:
    print('using cpu')
    device = torch.device("cpu")
model = resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("/home/mdatres/sicura_SSL/best_fold_0.pth", map_location="cpu"))

backend = "fbgemm"
my_qconfig = torch.quantization.get_default_qconfig(backend)
#my_qconfig = torch.ao.quantization.default_qconfig
model.qconfig = my_qconfig
model.to(device)
model.eval()
model.fuse_model()
print(model)
print("Quantization according to the observers: " + str(model.qconfig))
torch.quantization.prepare(model, inplace=True)

samples_for_class = [0, 0]
k = 0
model.eval()
with torch.no_grad():  # calibration is inference only, no gradients needed
    for d in train_data_loader:
        if k < 10000:
            x, y = d
            model(x.to(device))  # forward pass so the observers record activation ranges
            if y.item() == 0:
                samples_for_class[0] += 1
            else:
                samples_for_class[1] += 1
            k += 1
        else:
            break

print("Calibration successfully ended with a calbration set with this number of labels " + str(samples_for_class))
model.to("cpu")

model_q = torch.quantization.convert(model, inplace=False)
wt_compare_dict = ns.compare_weights(model.state_dict(), model_q.state_dict())
def compute_error(x, y):
    # Signal-to-quantization-noise ratio (SQNR), in dB, between the float
    # tensor x and its dequantized counterpart y
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)

for key in wt_compare_dict:
    print(key, compute_error(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

print("\nkeys of wt_compare_dict entry for conv1's weight:")
print(wt_compare_dict['conv1.weight'].keys())
print(wt_compare_dict['conv1.weight']['float'].shape)
print(wt_compare_dict['conv1.weight']['quantized'].shape)

The printed SQNR values are around 30 dB (is this quantization error high?). Do you have any suggestions on how to improve the performance on this task? Am I doing something wrong?

Best,
mdatres

How does the quantized model perform on your evaluation set? Usually that is a more representative measure of how well the quantized model approximates the fp32 model.
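For example, something along these lines (a minimal sketch, not part of your script: val_loader and torchmetrics' MatthewsCorrCoef are my assumptions, and the task argument needs a recent torchmetrics version):

import torch
import torchmetrics

def evaluate_mcc(model, loader):
    # Accumulate the Matthews correlation coefficient over the evaluation set
    metric = torchmetrics.MatthewsCorrCoef(task="binary")
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)  # converted int8 models must run on CPU
            metric.update(preds, y)
    return metric.compute().item()

# val_loader: a DataLoader over a held-out split, built like train_data_loader;
# score the fp32 model and the converted model the same way
print("fp32 MCC:     ", evaluate_mcc(model, val_loader))
print("quantized MCC:", evaluate_mcc(model_q, val_loader))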

The analysis you have done is usually used for debugging which parts of the model are more sensitive to quantization (see the PyTorch Numeric Suite tutorial in the PyTorch documentation), but the first step is usually to get an evaluation score.
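If the evaluation score confirms a drop, the same compare_weights idea extends to activations, e.g. (a sketch following that tutorial; img is my placeholder for one calibration batch, and ideally the first argument is a clean fp32 copy of the model):

# Log per-layer outputs of the float and quantized model on one batch and
# compare them; layers with low SQNR are the ones most hurt by quantization
act_compare_dict = ns.compare_model_outputs(model, model_q, img)
for key in act_compare_dict:
    print(key, compute_error(act_compare_dict[key]['float'][0],
                             act_compare_dict[key]['quantized'][0].dequantize()))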

The performance of the quantized model on the evaluation set is bad. The floating-point model reaches an MCC of 0.9, while its quantized version has an MCC of 0.35. I have also tried QAT, but the situation is the same: during evaluation of the fake-quantized model the MCC is around 0.9, but once I convert it the MCC is 0.4. I cannot understand why there is such a difference between the true and the fake quantized model. Do you have any suggestions?
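To be concrete, the gap shows up between these two steps (a rough sketch of the standard eager-mode QAT flow, not my exact code; evaluate_mcc is the helper from your sketch above):

# qat_model: the network after torch.quantization.prepare_qat and fine-tuning
qat_model.eval()
print("fake-quant MCC:", evaluate_mcc(qat_model, val_loader))   # ~0.9 for me

int8_model = torch.quantization.convert(qat_model.to("cpu"), inplace=False)
print("converted MCC: ", evaluate_mcc(int8_model, val_loader))  # ~0.4 for me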
Thanks for the reply!
Max