Accuracy after int8 quantization is higher than before quantization

After modifying the amount of calibration data, the model's accuracy after int8 quantization actually comes out higher than the original model's accuracy.

import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Assumed import: a quantizable ResNet-18 that provides fuse_model(); the
# original post does not show where resnet18 comes from.
from torchvision.models.quantization import resnet18

NUM_WORKERS = 4  # assumed value; not shown in the original post


# Note: the function is named get_imagenet but actually loads CIFAR-10.
def get_imagenet(dataset_dir='…/dataset/CIFAR10', batch_size=32):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        # transforms.Resize(256),
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root=dataset_dir, train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR10(root=dataset_dir, train=False, transform=test_transform, download=True)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=NUM_WORKERS,
                             pin_memory=True, shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=NUM_WORKERS,
                            pin_memory=True, shuffle=False)
    return trainloader, testloader

class quantizeModel(object):

    def __init__(self):
        super(quantizeModel, self).__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_loader, self.test_loader = get_imagenet()
        self.quant()

    def quant(self):
        model = self.load_model()
        model.eval()
        self.print_size_of_model(model)
        self.validate(model, "original_resnet18", self.test_loader)
        model.fuse_model()

        self.print_size_of_model(model)
        self.quantize(model)

    def load_model(self):
        model = resnet18()
        state_dict = torch.load("CIFAR10_resnet18.pth", map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        return model

    def print_size_of_model(self, model):
        torch.save(model.state_dict(), "temp.p")
        print('Size (MB):', os.path.getsize("temp.p") / 1e6)
        os.remove("temp.p")

    def validate(self, model, name, data_loader):
        with torch.no_grad():
            correct = 0
            total = 0
            acc = 0
            for data in data_loader:
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)
                output = model(images)

                _, predicted = torch.max(output, dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # Stop after ~1024 images (calibration data). Using >= instead
                # of == guards against batch sizes that do not divide 1024
                # evenly, where total would skip past 1024 and the loop would
                # run over the whole dataset.
                if total >= 1024:
                    break

            acc = round(100 * correct / total, 3)
            print('{{"metric": "{}_val_accuracy", "value": {}%}}'.format(name, acc))
            return acc

    def quantize(self, model):
        #model.qconfig = torch.quantization.default_qconfig
        #model.qconfig = torch.quantization.default_per_channel_qconfig

        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.observer.MinMaxObserver.with_args(reduce_range=True),
            weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(dtype=torch.qint8,
                                                                                  qscheme=torch.per_channel_affine))
        pmodel = torch.quantization.prepare(model)

        # calibration
        self.validate(pmodel, "quantize_per_channel_resnet18_train", self.train_loader)
        qmodel = torch.quantization.convert(pmodel)

        self.validate(qmodel, "quantize_per_channel_resnet18_test", self.test_loader)
        self.print_size_of_model(qmodel)

        torch.jit.save(torch.jit.script(qmodel), "quantization_per_channel_model18.pth")
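For completeness, a minimal sketch of loading the saved TorchScript model back for CPU inference (the 1x3x32x32 dummy input is just illustrative):

import torch

# Load the TorchScript int8 model saved by torch.jit.save above.
qmodel = torch.jit.load("quantization_per_channel_model18.pth")
qmodel.eval()

# Quantized modules run on CPU; feed a CIFAR-10-sized batch.
with torch.no_grad():
    dummy = torch.randn(1, 3, 32, 32)
    logits = qmodel(dummy)
    print(logits.argmax(dim=1))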

Original model accuracy: 71.76%

First quantization: batch_size: 32, calibration data: 2048
Quantized model accuracy: 71.51%

Second quantization: batch_size: 32, calibration data: 1024
Quantized model accuracy: 71.85%

Why does the accuracy become higher after quantization?

In addition, I thought that as long as the total number of calibration samples stays the same, the min/max range of the activations should be fixed, and so the quantized accuracy should be fixed too. However, I found that with the total amount of calibration data unchanged, modifying batch_size changes the accuracy after quantization. What is the reason?
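To make the comparison concrete, here is a sketch that collects the first 1024 images exactly as the calibration loop sees them for two batch sizes; calibration_images is a hypothetical helper, and train_dataset is assumed to be available as constructed inside get_imagenet above:

import torch
from torch.utils.data import DataLoader

def calibration_images(dataset, batch_size, n=1024):
    # Hypothetical helper: gather the first n images in loader order
    # (shuffle=False), mirroring the calibration loop above.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batches, total = [], 0
    for images, _ in loader:
        batches.append(images)
        total += images.size(0)
        if total >= n:
            break
    return torch.cat(batches)[:n]

a = calibration_images(train_dataset, batch_size=32)
b = calibration_images(train_dataset, batch_size=64)
print(torch.equal(a, b))
# Caveat: train_transform above includes RandomResizedCrop and
# RandomHorizontalFlip, so the two passes can draw different augmentations
# and print False even though the underlying 1024 images are the same.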

Is there any randomness in which specific dataset slice is getting used for calibration? Can you reproduce the accuracy changes if you set torch.manual_seed(0)?

torch.manual_seed(191009)

train_dataset = datasets.CIFAR10(root=dataset_dir, train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(root=dataset_dir, train=False, transform=test_transform, download=True)

trainloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=NUM_WORKERS,
                         pin_memory=True, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=NUM_WORKERS,
                        pin_memory=True, shuffle=False)

With the seed set, if batch_size is modified the accuracy after quantization still changes; if nothing is modified, the accuracy stays the same.
Why does the accuracy become higher after quantization?

We don't expect accuracy to increase due to quantization; this is likely random variation. To test this theory, you could run evaluation on various slices of data unseen in training, as in the sketch below. I would expect the mean difference in accuracy over a set of slices to be a slight drop for the quantized model compared to the floating point model.
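For example, a minimal sketch of that test, where float_model and qmodel are placeholders for your FP32 and int8 models and test_dataset is the CIFAR-10 test set (10,000 images):

import torch
from torch.utils.data import DataLoader, Subset

def accuracy(model, loader, device="cpu"):
    # Standard top-1 accuracy over one data loader.
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

# Score both models on ten disjoint 1000-image slices of the test set.
diffs = []
for start in range(0, 10000, 1000):
    subset = Subset(test_dataset, range(start, start + 1000))
    loader = DataLoader(subset, batch_size=32, shuffle=False)
    diffs.append(accuracy(qmodel, loader) - accuracy(float_model, loader))

# A slightly negative mean would match the expectation stated above.
print(sum(diffs) / len(diffs))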

If MinMax observers are used, we do not expect the ordering or batch size of the calibration data to matter, as long as the same dataset gets seen during calibration. One thing to look into would be whether the way the evaluation score is measured depends on batch size, and whether you are feeding exactly the same set of images through; the sketch below shows one way to check the observed ranges.
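A sketch against the eager-mode API used above: after the calibration pass and before convert(), dump the running min/max recorded by each activation observer and diff the output across batch sizes (MinMaxObserver instances expose min_val and max_val):

def dump_activation_ranges(prepared_model):
    # Print the range recorded by each activation observer that
    # torch.quantization.prepare() inserted into the model.
    for name, module in prepared_model.named_modules():
        if isinstance(module, torch.quantization.observer.MinMaxObserver):
            print(name, module.min_val, module.max_val)

# Call on pmodel after calibration; if two batch sizes produce identical
# ranges, the converted models should match as well.
dump_activation_ranges(pmodel)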

Thank you very much for your reply.

model.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.observer.MinMaxObserver.with_args(dtype=torch.quint8,
                                                                    # MinMaxObserver is per-tensor, so its qscheme
                                                                    # must be a per-tensor one:
                                                                    qscheme=torch.per_tensor_affine,
                                                                    reduce_range=True),
    weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(dtype=torch.qint8,
                                                                          qscheme=torch.per_channel_affine,
                                                                          reduce_range=False))

With the total number of calibration samples unchanged, I tested different batch_sizes (8, 16, 32, 64), and the quantized accuracy still fluctuates slightly.
Since batch_size has little effect on the quantization result, I can choose the configuration with the highest accuracy as the final quantized model, as in the selection loop sketched below.
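A sketch of that selection loop, under a few assumptions: load_model here stands for a standalone version of the method above that returns a fresh FP32 model on CPU, accuracy is the helper sketched earlier, and calibrate is a hypothetical helper that feeds the first 1024 images through the prepared model:

import torch

def calibrate(model, loader, n=1024):
    # Hypothetical helper: run the first n images through the prepared
    # model so the observers record activation ranges.
    seen = 0
    with torch.no_grad():
        for images, _ in loader:
            model(images)
            seen += images.size(0)
            if seen >= n:
                break

qconfig = torch.quantization.QConfig(
    activation=torch.quantization.observer.MinMaxObserver.with_args(reduce_range=True),
    weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_affine))

best_acc, best_model = -1.0, None
for bs in (8, 16, 32, 64):
    train_loader, test_loader = get_imagenet(batch_size=bs)
    model = load_model()          # reload a fresh FP32 model for each trial
    model.eval()
    model.fuse_model()
    model.qconfig = qconfig
    prepared = torch.quantization.prepare(model)
    calibrate(prepared, train_loader)
    qmodel = torch.quantization.convert(prepared)
    acc = accuracy(qmodel, test_loader)
    if acc > best_acc:
        best_acc, best_model = acc, qmodel

print("best int8 accuracy:", best_acc)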
