class quantizeModel(object):
    """Post-training static quantization pipeline for a ResNet-18 checkpoint.

    Constructing an instance runs the whole pipeline immediately:
    load the float model, report its size and accuracy, fuse modules,
    calibrate per-channel observers on the training loader, convert to a
    quantized model, evaluate it, and save it as a TorchScript archive.
    """

    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): get_imagenet() is defined elsewhere; assumed to return
        # (train_loader, test_loader) yielding (images, labels) batches.
        self.train_loader, self.test_loader = get_imagenet()
        self.quant()

    def quant(self):
        """Load the float model, evaluate and fuse it, then quantize it."""
        model = self.load_model()
        # Module fusion and static quantization both expect eval mode.
        model.eval()
        self.print_size_of_model(model)
        self.validate(model, "original_resnet18", self.test_loader)
        # Assumes resnet18() returns a quantization-ready model exposing
        # fuse_model() (e.g. QuantStub/DeQuantStub wrapped) — TODO confirm.
        model.fuse_model()
        self.print_size_of_model(model)
        self.quantize(model)

    def load_model(self):
        """Build resnet18() and load weights from ``CIFAR10_resnet18.pth``.

        Returns:
            The model, moved to ``self.device``.
        """
        # NOTE(review): the checkpoint name says CIFAR10 while the data comes
        # from get_imagenet(); confirm the number of classes matches.
        model = resnet18()
        state_dict = torch.load("CIFAR10_resnet18.pth", map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        return model

    def print_size_of_model(self, model):
        """Print the serialized size of ``model``'s state_dict in megabytes.

        The state_dict is serialized into an in-memory buffer rather than a
        scratch file in the working directory, so nothing is left on disk if
        saving fails and concurrent runs cannot clobber each other's file.
        """
        import io

        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        print('Size (MB):', buffer.getbuffer().nbytes / 1e6)

    def validate(self, model, name, data_loader, max_samples=1024, device=None):
        """Compute top-1 accuracy of ``model`` on (a prefix of) ``data_loader``.

        Args:
            model: callable mapping a batch of inputs to per-class scores;
                the prediction is the argmax over dim 1.
            name: prefix for the metric name in the printed line.
            data_loader: iterable of ``(inputs, labels)`` batches.
            max_samples: stop once at least this many samples have been seen
                (the final batch is counted in full). ``None`` evaluates the
                whole loader.
            device: device each batch is moved to; defaults to ``self.device``.

        Returns:
            Accuracy in percent, rounded to 3 decimals (0.0 if no samples).
        """
        device = self.device if device is None else device
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                # ">=" rather than "==": with a batch size that does not
                # divide max_samples, "==" would never stop early.
                if max_samples is not None and total >= max_samples:
                    break
        acc = round(100 * correct / total, 3) if total else 0.0
        print('{{"metric": "{}_val_accuracy", "value": {}%}}'.format(name, acc))
        return acc

    def quantize(self, model):
        """Quantize ``model`` statically with per-channel weights, then evaluate and save it.

        Eager-mode quantized kernels run only on CPU, so the model is moved to
        the CPU before observers are inserted. Calibration and evaluation of
        the quantized model also run there. The TorchScript archive is written
        to ``quantization_per_channel_model18.pth``.
        """
        cpu = torch.device('cpu')
        model.to(cpu)
        model.qconfig = torch.quantization.QConfig(
            # reduce_range=True quantizes activations to a range one bit
            # narrower (see the MinMaxObserver documentation).
            activation=torch.quantization.observer.MinMaxObserver.with_args(reduce_range=True),
            weight=torch.quantization.observer.PerChannelMinMaxObserver.with_args(
                dtype=torch.qint8,
                qscheme=torch.per_channel_affine,
            ),
        )
        pmodel = torch.quantization.prepare(model)
        # Calibration: forward passes over training data populate the
        # observers' min/max statistics used by convert().
        # NOTE(review): metric names contain typos ('quntize', 'resent18',
        # 'chaannel'); kept verbatim in case downstream tooling parses them.
        self.validate(pmodel, "quntize_per_channel_resent18_train", self.train_loader, device=cpu)
        qmodel = torch.quantization.convert(pmodel)
        self.validate(qmodel, "quntize_per_chaannel_resent18_test", self.test_loader, device=cpu)
        self.print_size_of_model(qmodel)
        torch.jit.save(torch.jit.script(qmodel), "quantization_per_channel_model18.pth")