Histogram Calibration taking incredibly long time

Hello, I am currently experimenting with Quantization Aware Training (QAT) for common YOLO models. I expect a solid accuracy boost from using histogram calibration instead of min/max. However, in testing, histogram calibration takes several orders of magnitude longer than min/max: min/max takes about 30 minutes on my calibration set, while histogram is estimated at >100 hours. Though I expected histogram to be slower, this seems a bit absurd. I have initialized the model’s tensor quantizers with the QuantDescriptor:

from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# Set the default input quantizer descriptor (e.g. calib_method='histogram')
# before the quantized modules are instantiated
quant_desc_input = QuantDescriptor(calib_method=opt.qat_calibrator)
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
quant_modules.initialize()

And I am following the standard TensorRT/PyTorch tutorial for calibrating the model:

import torch
from tqdm import tqdm
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn


def collect_stats(model, data_loader, num_batches, device):
    """Feed data to the network and collect statistics."""
    # Enable calibrators and disable quantization during calibration
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats; no_grad avoids
    # building the autograd graph during the forward passes
    with torch.no_grad():
        for i, (img, _, _, _) in tqdm(enumerate(data_loader), total=num_batches):
            img = img.to(device, non_blocking=True).float() / 255.0
            model(img)
            if i + 1 >= num_batches:
                break

    # Disable calibrators and re-enable quantization
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    """Load calibration results and compute the amax (clip range) per quantizer."""
    print('Computing amax')
    max_count = 0
    hist_count = 0
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    # Max calibration takes no extra arguments
                    max_count += 1
                    module.load_calib_amax()
                else:
                    # Histogram calibration: kwargs select the method,
                    # e.g. method='percentile', 'mse' or 'entropy'
                    print('Using histogram')
                    hist_count += 1
                    module.load_calib_amax(**kwargs)
            print(f"{name:40}: {module}")
    print(f"max calibrators: {max_count}, histogram calibrators: {hist_count}")
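
For reference, the two functions are invoked roughly like this (the num_batches value and the amax method below are illustrative, not my exact settings):

# Calibrate, then turn the collected statistics into amax values
collect_stats(model, data_loader, num_batches=1024, device='cuda')
compute_amax(model, method='percentile', percentile=99.99)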

The issue seems to happen during the collect_stats function: every CPU core gets pegged for each batch and progress is incredibly slow. Any help would be appreciated.
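
To narrow down where the time goes, here is a minimal timing sketch (I am assuming pytorch-quantization's HistogramCalibrator constructor and collect() method here; the tensor shape is just a stand-in for a typical YOLO activation):

import time

import torch
from pytorch_quantization import calib

# Time a single histogram collection pass on a YOLO-sized activation.
# If this one call is already slow and pegs the CPU cores, the
# collector itself is the bottleneck rather than the data pipeline.
x = torch.randn(8, 64, 160, 160)
calibrator = calib.HistogramCalibrator(num_bits=8, axis=None, unsigned=False)

start = time.time()
calibrator.collect(x)
print(f"one collect() call took {time.time() - start:.3f}s")

If the installed version of pytorch-quantization exposes a torch_hist option on HistogramCalibrator, that might also be worth testing, since it moves the histogram computation into torch rather than numpy.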

Hi @Quinn_Graehling, off the top of my head I don’t see a reason why histogram calibration should take that much longer. Can you share a minimal script to reproduce?

@jcaip I am currently following this tutorial from PyTorch:

https://pytorch.org/TensorRT/_notebooks/vgg-qat.html

However, this tutorial does not actually account for histogram calibration: quant_modules.initialize() at cell [7] initializes all layers with a MaxCalibrator, so regardless of the kwargs you pass to compute_amax() it will always perform max calibration. In order to tie a HistogramCalibrator to the layers, we did the following in addition to quant_modules.initialize():

quant_desc_input = QuantDescriptor(calib_method=opt.qat_calibrator)
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
quant_modules.initialize()

This does tie HistogramCalibrators to the activation inputs. This is the only major difference we have made, besides using a YOLO-based architecture instead of the VGG network. I believe that if you follow the notebook from the PyTorch link with the addition of our steps, it will reproduce the same results. I will try this on my end as well to make sure.
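
A quick way to confirm the binding took effect is to print each quantizer's calibrator type after the model is built (a small sketch):

from pytorch_quantization import nn as quant_nn

# Each TensorQuantizer's input should now report a HistogramCalibrator
for name, module in model.named_modules():
    if isinstance(module, quant_nn.TensorQuantizer):
        print(f"{name:40}: {type(module._calibrator).__name__}")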

Ah, I believe the quantization library used in that tutorial is NVIDIA's pytorch-quantization (see the pytorch-quantization master documentation), and not the native PyTorch quantization workflow. You may have better luck asking there.