Dear all,
I am new to quantization, but I am interested in trying it out on medical images. I followed the PyTorch tutorial on MNIST and everything works fine. I then tried it with my own data, but the resulting performance is terrible. I have tried all the available quantization pipelines (static, dynamic, and quantization-aware training), but none of them seems to work on my task. Below I attach the code that I used for static quantization.
import torch
import torchvision
from tqdm import tqdm
import sklearn
import numpy as np
from torchvision import transforms
import random
from torchvision.datasets import ImageFolder
from torchvision.models.quantization import resnet50
import torchmetrics
import torch.nn.functional as F
from collections import Counter
import torch.quantization._numeric_suite as ns
from torch.utils.data import DataLoader, Dataset
from per_patient_sicura.Qresnet_pp import QResnet50pp
from per_patient_sicura.sicura_dataset_pp import SicuraPPDataset
import os
# Root directory handed to ImageFolder further down to build the dataset
# whose loader feeds the calibration loop.
bonus_dataset = '/home/mdatres/data/si-cura/data_bonus/dataset_bonus/CU_MC_new_datasets'
# Second dataset location; not referenced anywhere in the code shown here.
path_to_ourdata = '/home/mdatres/data/dataset_per_patient'
def compute_class_weight(dataset: Dataset):
    """Compute inverse-frequency weights for a WeightedRandomSampler.

    Each class receives a weight proportional to ``num_samples / class_count``;
    the class weights are then normalised so that they sum to 1. Every sample
    is assigned the weight of its own class.

    The class weights are indexed by the integer class label, so the result
    does not depend on the order in which labels appear in ``dataset.targets``.

    Args:
        dataset: Any object exposing ``targets``, a sequence of non-negative
            integer class labels (one per sample).

    Returns:
        A tuple ``(samples_weight, class_weight)`` where ``samples_weight`` is
        a float64 tensor with one weight per sample and ``class_weight`` is a
        float32 tensor with one weight per class label (``0 .. max_label``).
        Labels that never occur get weight 0. An empty dataset yields two
        empty tensors.
    """
    targets = np.asarray(dataset.targets, dtype=np.int64).reshape(-1)

    # counts[c] is the number of samples whose label is c.
    counts = np.bincount(targets)

    # Inverse frequency per class; absent classes keep weight 0 instead of
    # producing a division by zero.
    weight = np.zeros(len(counts), dtype=np.float64)
    present = counts > 0
    weight[present] = targets.size / counts[present]

    # Normalise so that the class weights sum to 1 (skipped when empty).
    total = weight.sum()
    if total > 0:
        weight = weight / total

    # Look up each sample's weight directly by its class label.
    samples_weight = torch.from_numpy(weight[targets])
    return samples_weight, torch.from_numpy(weight).float()
# Accumulators; nothing in the code shown here fills them.
all_val_loss = []
save_test = []

# Preprocessing applied to every image: resize to 256x256, convert to tensor.
# NOTE(review): normalisation is disabled here — confirm this matches the
# preprocessing the checkpoint was trained with.
test_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # normalize,
    ]
)

batch_size = 1
train_val_data = ImageFolder(root=bonus_dataset, transform=test_transforms)

# Per-sample weights drive the sampler; the per-class vector is kept as well.
sample_weight_per_class_train, train_weight = compute_class_weight(train_val_data)
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weight_per_class_train,
    num_samples=len(sample_weight_per_class_train),
)
train_data_loader = DataLoader(
    train_val_data,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=0,
    pin_memory=True,
)

# Class index of every (path, class_index) entry in the image folder.
label_train_val_data = [label for _, label in train_val_data.samples]
epochs = 50
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
print('using ' + _device_name)
device = torch.device(_device_name)
# Build torchvision's quantizable ResNet-50 and replace the head with a
# 2-output linear layer before restoring the checkpoint weights.
model = resnet50()
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 2)

# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("/home/mdatres/sicura_SSL/best_fold_0.pth", map_location="cpu")
model.load_state_dict(checkpoint)

# Quantization configuration for the chosen backend.
backend = "fbgemm"
my_qconfig = torch.quantization.get_default_qconfig(backend)
# Alternative left disabled: torch.ao.quantization.default_qconfig
model.qconfig = my_qconfig

# Switch to eval mode and fuse modules, then insert the observers in place.
model.to(device)
model.eval()
model.fuse_model()
print(model)
print(f"Quantization according to the observers: {model.qconfig!s}")
torch.quantization.prepare(model, inplace=True)
# Calibration pass: feed up to `max_calibration_batches` batches through the
# prepared model so the inserted observers can record statistics, and count
# how many samples of each label were seen along the way.
max_calibration_batches = 10000
samples_for_class = [0, 0]  # [count of label 0, count of every other label]
k = 0  # number of batches consumed so far
model.eval()
# Gradients are not needed while calibrating; disabling autograd avoids
# building graphs for every forward pass.
with torch.no_grad():
    for x, y in train_data_loader:
        if k >= max_calibration_batches:
            break
        model(x.to(device))
        # Count every label in the batch, so this also works if batch_size > 1.
        for label in y.reshape(-1).tolist():
            if label == 0:
                samples_for_class[0] += 1
            else:
                samples_for_class[1] += 1
        k += 1
print("Calibration successfully ended with a calbration set with this number of labels " + str(samples_for_class))
# Move the calibrated model to CPU, then produce a separate quantized copy
# (the prepared float model is left untouched because inplace=False).
model.to("cpu")
model_q = torch.quantization.convert(model, inplace=False)

# Pair up the weights of the float model and the quantized model.
float_state = model.state_dict()
quantized_state = model_q.state_dict()
wt_compare_dict = ns.compare_weights(float_state, quantized_state)
def compute_error(x, y):
    """Return 20 * log10(||x|| / ||x - y||) for two tensors.

    ``x`` is treated as the reference and ``x - y`` as the deviation from it;
    larger values mean ``y`` is closer to ``x``.
    """
    reference_norm = torch.norm(x)
    deviation_norm = torch.norm(x - y)
    ratio = reference_norm / deviation_norm
    return 20 * torch.log10(ratio)
# Report the error figure for every compared weight tensor.
for key, entry in wt_compare_dict.items():
    print(key, compute_error(entry['float'], entry['quantized'].dequantize()))

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

# Inspect the entry for conv1's weight in more detail.
print("\nkeys of wt_compare_dict entry for conv1's weight:")
conv1_entry = wt_compare_dict['conv1.weight']
print(conv1_entry.keys())
print(conv1_entry['float'].shape)
print(conv1_entry['quantized'].shape)
The printed errors are around 30 (is this quantization error high?). Do you have any suggestions on how to improve the performance on this task? Am I doing something wrong?
Best,
mdatres