Dear all,
I’m trying to quantize a ResNet50. I have decided to use QAT to train a fake quantize version of the model. The training statistics are good, but after torch.quantization.convert(model, inplace=True)
they are completely different. Indeed if I evaluate the model in the fake quantized version I obtain a F1 score of 0.99 while after convert it reaches only 0.70. I don’t think it is normal, there is something wrong in the code (I attach it below).
import torch
import torchvision
from tqdm import tqdm
import sklearn
import numpy as np
from torchvision import transforms
import random
from torchvision.datasets import ImageFolder
from torchvision.models.quantization import resnet50
import torchmetrics
import torch.nn.functional as F
from collections import Counter
#from quantized_models.resnet50_sicura import ResNet50_sicura
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, Dataset
from observers.PACTFakeQuantised import PACTFakeQuantize
from split_data import generate_data_structure
from per_patient_sicura.Qresnet_pp import QResnet50pp
from per_patient_sicura.sicura_dataset_pp import SicuraPPDataset
import os
from torch.ao.quantization import FakeQuantize
from torch.ao.quantization.observer import MovingAveragePerChannelMinMaxObserver
bonus_dataset = '/home/mdatres/data/si-cura/data_bonus/dataset_bonus/CU_MC_new_datasets'
path_to_ourdata = '/home/mdatres/data/dataset_per_patient'
class EarlyStopper:
def __init__(self, patience=1, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.min_validation_loss = np.inf
def early_stop(self, validation_loss):
if validation_loss < self.min_validation_loss:
self.min_validation_loss = validation_loss
self.counter = 0
elif validation_loss > (self.min_validation_loss + self.min_delta):
self.counter += 1
if self.counter >= self.patience:
return True
return False
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1, p = 1):
self.std = std
self.mean = mean
self.p = p
def __call__(self, tensor):
if random.uniform(0, 1) < self.p:
return tensor + torch.randn(tensor.size()) * self.std + self.mean
else:
return tensor
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Subset(Dataset):
r"""
Subset of a dataset at specified indices.
Arguments:
dataset (Dataset): The whole Dataset
indices (sequence): Indices in the whole set selected for subset
"""
def __init__(self, dataset, indices, transform):
self.dataset = dataset
self.indices = indices
self.transform = transform
def __getitem__(self, idx):
im, labels = self.dataset[self.indices[idx]]
return self.transform(im), labels
def __len__(self):
return len(self.indices)
def compute_class_weight(dataset: Subset, indices):
"""
Compute the weight for each class that need to be passed to the sampler in the
torch.utils.data.WeightedRandomSampler when building the train_dataloader
"""
dataset = dataset.dataset
count_element_classes = Counter(dataset.targets[i] for i in indices) #store a the num of elem for each class {'Class1':num1, 'Class2':num2}
l = np.array(list(count_element_classes.values())) #from a dictionary to a np,array [numClass1,numClass2]
weight = l/(sum(count_element_classes.values())) #get the weigth for each class [weighClass1,weighClass2]
samples_weight = np.array([weight[sample] for j,sample in enumerate(dataset.targets) if j in indices]) #get the vector of sample weights
samples_weight=torch.from_numpy(samples_weight) #transform into tensor
return samples_weight, torch.from_numpy(weight).float()
def compute_class_weight(dataset: Subset, indices):
"""
Compute the weight for each class that need to be passed to the sampler in the
torch.utils.data.WeightedRandomSampler when building the train_dataloader
"""
dataset = dataset.dataset
count_element_classes = Counter(dataset.targets[i] for i in indices) #store a the num of elem for each class {'Class1':num1, 'Class2':num2}
l = np.array(list(count_element_classes.values()))# from a dictionary to a np,array [numClass1,numClass2]
weight = [sum(l)/r for r in l]
weight = weight/(sum(weight))
#get the weigth for each class [weighClass1,weighClass2]
samples_weight = np.array([weight[sample] for j,sample in enumerate(dataset.targets) ]) #get the vector of sample weight
samples_weight=torch.from_numpy(samples_weight)# transform into tensor
return samples_weight, torch.from_numpy(weight).float()
all_val_loss = []
save_test = []
my_qconfig = torch.quantization.get_default_qconfig("fbgemm")
#train transforms
train_transforms = transforms.Compose([
transforms.Resize((256,256)),
#random_colour_transform,
transforms.RandomApply(
[transforms.ColorJitter(brightness=0.4, contrast=0.4)],
p=0.6
),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.ToTensor(),
#AddGaussianNoise(std=0.01, p=0.5),
#normalize,
])
#test transforms
test_transforms = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor(),
#normalize,
])
batch_size = 64
train_val_data = ImageFolder(
root=bonus_dataset)
label_train_val_data = [sample[1] for sample in train_val_data.samples]
epochs = 50
if torch.cuda.is_available():
print('using cuda')
device = torch.device("cuda")
else:
print('using cpu')
device = torch.device("cpu")
import wandb
wandb.init(project="sicura_quantized")
#5-fold stratified
kfold_stratified = StratifiedKFold(
n_splits=10, shuffle=True)
for fold, (train_ids, val_ids) in tqdm(enumerate(kfold_stratified.split(train_val_data, label_train_val_data))):
print("------"*20)
print('Fold {} starting...'.format(fold + 1))
min_val_loss = float('inf')
#train-val split for each fold of the dataset
train_dataset_cv = Subset(train_val_data, train_ids, transform= train_transforms)
sample_weight_per_class_train, train_weight = compute_class_weight(train_dataset_cv, train_ids)
#define a sampler for class imbalace in the training
sampler = torch.utils.data.WeightedRandomSampler(sample_weight_per_class_train, len(sample_weight_per_class_train))
train_data_loader = DataLoader(train_dataset_cv, batch_size=batch_size,shuffle=True, num_workers=0, pin_memory=True)
val_dataset_cv = Subset(train_val_data, val_ids, transform= test_transforms)
weight_per_class_val , val_weight= compute_class_weight(val_dataset_cv, val_ids)
val_data_loader = DataLoader(
val_dataset_cv, batch_size=batch_size, num_workers=8, shuffle=True, pin_memory=True)
model = resnet50(weights="ResNet50_Weights.IMAGENET1K_V2")
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("/home/mdatres/sicura_SSL/best_fold_0.pth"))
model.qconfig = my_qconfig
print("Quantization according to the observers: " + str(model.qconfig))
torch.quantization.prepare_qat(model, inplace=True)
f1 = torchmetrics.F1Score(task='binary', num_classes=2).to(device)
mcc = torchmetrics.MatthewsCorrCoef(task = 'binary').to(device)
criterion = torch.nn.CrossEntropyLoss(weight=train_weight)
# #freeze params
# for params in model.parameters():
# params.requires_grad = False# disable autograd tracking
# for params in model.layer3.parameters():
# params.requires_grad = True
# #unfreeze params onlny in 3rd and 4th layers to retrain them
# for params in model.layer4.parameters():
# params.requires_grad = True
optimizer = torch.optim.Adam(
model.parameters(), lr = 1e-4)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
if torch.cuda.is_available():
print('using cuda')
device = torch.device("cuda")
else:
print('using cpu')
device = torch.device("cpu")
model.to(device)
criterion.to(device)
early_stopper = EarlyStopper(patience=1, min_delta=0.0)
for epoch in tqdm(range(epochs + 1)):
model.train()
train_loss = 0.0
for data, target in train_data_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
if epoch>=5:
model.apply(torch.quantization.disable_observer)
if epoch > 7:
# Freeze batch norm mean and variance estimates
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
output = model(data)
l1_norm_all = torch.tensor([0.0]).to(device)
for p in model.parameters():
if p.requires_grad:
l1_norm_all += p.abs().sum()
loss = criterion(output, target) #+ 1e-5*l1_norm_all
loss.backward()
optimizer.step()
train_loss += loss.data.item() * data.size(0)
f1.update(torch.argmax(output.softmax(dim=-1),1), target)
mcc.update(torch.argmax(output.softmax(dim=-1),1), target)
wandb.log({'lr_' + str(fold): optimizer.state_dict()['param_groups'][0]['lr']})
f1_save = f1.compute()
mcc_save = mcc.compute()
train_loss /= len(train_data_loader.dataset)
print('Epoch_Train:{}, T_Loss:{:.2f}, F1_score= {:.2f}, MCC= {:.4f}'
.format(epoch, train_loss, f1_save, mcc_save ))
wandb.log({'f1_train'+ '_fold_' + str(fold): f1_save.item(), 'mcc_train' + '_fold_' + str(fold): mcc_save.item(), 'loss_train'+'_fold_' + str(fold): train_loss})
model.eval()
val_loss = 0.0
with torch.no_grad():
for data, target in val_data_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# get the index of the max log-probability
loss = F.cross_entropy(output, target, weight=val_weight.to(device))
l1_norm_all = torch.tensor(0.0).to(device)
for p in model.parameters():
if p.requires_grad:
l1_norm_all += p.abs().sum()
#loss += 1e-5*l1_norm_all
f1.update(torch.argmax(output.softmax(dim=-1),1), target)
mcc.update(torch.argmax(output.softmax(dim=-1),1), target)
val_loss += loss.data.item() * data.size(0)
val_loss = val_loss/len(val_data_loader.dataset)
lr_scheduler.step(metrics=val_loss)
f1val = f1.compute()
mccval = mcc.compute()
print('Epoch_Val:{}, Val_Loss:{:.2f}, F1= {:.2f}, MCC= {:.4f}'
.format(epoch, val_loss, f1val, mccval ))
wandb.log({'acc_val' + '_fold_' + str(fold): f1val.item(), 'mcc_val' + '_fold_'+ str(fold): mccval.item(), 'loss_val' + '_fold_' + str(fold): val_loss})
if early_stopper.early_stop(val_loss):
break
if val_loss < min_val_loss:
print("The best is at epoch " + str(epoch))
torch.save(model.state_dict(), '/home/mdatres/sicura_SSL/result_pact/bestq' + '_fold_'+str(fold) + '.pth')
min_val_loss = val_loss
all_val_loss.append((min_val_loss, '/home/mdatres/sicura_SSL/result_pact/bestq' + '_fold_'+str(fold) + '.pth'))
#Testing phase
model = resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.qconfig = my_qconfig
print("Quantization according to the observers: " + str(model.qconfig))
torch.quantization.prepare_qat(model, inplace=True)
data = SicuraPPDataset(path_to_ourdata, test_transforms)
def second_item(data):
return data[0]
all_val_loss.sort(key=second_item, reverse=True)
with open('/home/mdatres/sicura_SSL/result_pact/best_all.txt', 'w') as f:
f.write('The best choosen model is ' + all_val_loss[0][1])
model.load_state_dict(torch.load(all_val_loss[0][1]))
model.to("cpu")
model.eval()
torch.quantization.convert(model, inplace=True)
netpp = QResnet50pp(threshold=0.7,model=model)
netpp.to("cpu")
netpp.eval()
f1_test = torchmetrics.F1Score(task='binary').to("cpu")
mcc_test = torchmetrics.MatthewsCorrCoef(task = 'binary').to("cpu")
for d in data:
o = netpp(input=d, device = "cpu", k=1)
f1_test.update(torch.tensor([o]).to("cpu"), torch.tensor([d[1]]).to("cpu"))
mcc_test.update(torch.tensor([o]).to("cpu"), torch.tensor([d[1]]).to("cpu"))
f1_final = f1_test.compute()
mcc_final = mcc_test.compute()
print('Test Statistic'+': '+ 'F1_score: ' +str(f1_final.item())+ ' MCC_score: ' +str(mcc_final.item()))
save_test.append({"F1_score": str(f1_final.item()), "MCC_score": str(mcc_final.item())})
Thanks in advance,
mdatres