Hello, I am currently experimenting with Quantization-Aware Training (QAT) for common YOLO models. I expect a solid accuracy boost from using histogram calibration instead of min/max. However, in my tests histogram calibration takes several orders of magnitude longer than min/max. I expected histogram to be slower, but the gap seems absurd: min/max takes about 30 minutes with my calibration set, while histogram is estimated to take over 100 hours. I have initialized the model's tensors with the QuantDescriptor:
quant_desc_input = QuantDescriptor(calib_method=opt.qat_calibrator)
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
quant_modules.initialize()
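For context, my understanding is that quant_modules.initialize() works by monkey-patching: it swaps the torch.nn layer classes for their quant_nn counterparts, so it has to run before the model is constructed. A toy sketch of the mechanism (class names and the fake namespace are illustrative, not the library's internals):

```python
import types

# Stand-ins for torch.nn.Conv2d and quant_nn.QuantConv2d (illustrative only).
class Conv2d:
    pass

class QuantConv2d(Conv2d):
    pass

# A fake "torch.nn" namespace to patch.
nn = types.SimpleNamespace(Conv2d=Conv2d)

def initialize():
    # Replace the stock class so any model built afterwards picks up
    # the quantized variant automatically.
    nn.Conv2d = QuantConv2d

initialize()
layer = nn.Conv2d()  # now an instance of the quantized wrapper
```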
And I am following the standard TensorRT/PyTorch tutorial for calibrating the model:
def collect_stats(model, data_loader, num_batches, device):
    """Feed data to the network and collect statistics."""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats
    # (no autograd needed during calibration)
    with torch.no_grad():
        for i, (img, _, _, _) in tqdm(enumerate(data_loader), total=num_batches):
            img = img.to(device, non_blocking=True).float() / 255.0
            model(img)
            if i >= num_batches:
                break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
def compute_amax(model, **kwargs):
    # Load calib result
    print('Computing amax')
    max_count = 0
    hist_count = 0
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    max_count += 1
                    module.load_calib_amax()
                else:
                    print('Using histogram')
                    hist_count += 1
                    module.load_calib_amax(**kwargs)
            print(f"{name:40}: {module}")
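For what it's worth, my mental model of what load_calib_amax computes from the collected histogram (when called with a percentile method) is roughly this — a simplified sketch, not the library's implementation:

```python
def percentile_amax(hist, bin_edges, percentile=99.99):
    """Pick the smallest bin edge covering `percentile` percent of all
    collected absolute values (simplified; the real calibrator differs)."""
    total = sum(hist)
    target = total * percentile / 100.0
    running = 0
    for count, upper_edge in zip(hist, bin_edges[1:]):
        running += count
        if running >= target:
            return upper_edge
    return bin_edges[-1]

# 100 values: 90 fall in [0,1), 9 in [1,2), 1 outlier in [2,3)
amax = percentile_amax([90, 9, 1], [0.0, 1.0, 2.0, 3.0], percentile=99.0)
# -> 2.0 (the top 1% outlier is clipped)
```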
The slowdown happens inside the collect_stats function: during histogram calibration every CPU core gets saturated on every batch and the loop runs incredibly slowly. Any help would be appreciated.
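In case it helps narrow things down, my understanding of the cost difference: max calibration is a single reduction per batch that can stay on the GPU, while the histogram calibrator has to copy every activation tensor to the CPU and bin each value (per quantizer, per batch), which would explain the CPU saturation. A pure-Python toy of the per-batch work (illustrative only — the real calibrators live in pytorch_quantization.calib):

```python
class ToyMaxCalibrator:
    def __init__(self):
        self.amax = 0.0

    def collect(self, values):
        # One reduction per batch; on the real GPU path this never
        # leaves the device.
        self.amax = max(self.amax, max(abs(v) for v in values))

class ToyHistCalibrator:
    def __init__(self, num_bins=2048, max_range=10.0):
        self.num_bins = num_bins
        self.width = max_range / num_bins
        self.hist = [0] * num_bins

    def collect(self, values):
        # Every value is individually binned; the real calibrator also
        # pays a GPU-to-CPU copy for each batch.
        for v in values:
            b = min(int(abs(v) / self.width), self.num_bins - 1)
            self.hist[b] += 1

batch = [0.1, -0.5, 2.0, 9.9]
mx, hc = ToyMaxCalibrator(), ToyHistCalibrator()
mx.collect(batch)
hc.collect(batch)
```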