Quantization aware training, extremely slow on GPU

Hey all,

I’ve been experimenting with quantization-aware training (QAT) using PyTorch 1.3.
I managed to adapt my model as demonstrated in the tutorial.
The documentation mentions that fake quantization is possible on GPU, however I notice that it is extremely slow.
Monitoring nvidia-smi shows that I only use about 7% of the GPU, while utilization is close to 100% when using the non-QAT version of the model.
Is this expected or am I doing something wrong?


I would assume this is expected, since FakeQuantize adds extra operations on the tensor values to simulate the quantization.

PyTorch 1.3 doesn’t provide quantized operator implementations on CUDA yet - this is a direction of future work. Move the model to CPU in order to test the quantized functionality.
Quantization-aware training (through FakeQuantize) supports both CPU and CUDA.
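
To make that CPU/GPU split concrete, here is a minimal sketch of the intended workflow (the TinyNet module and the tensor shapes are illustrative placeholders, not from the original post): the FakeQuantize modules run during training on CUDA, while the converted INT8 model runs on CPU.

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()
            self.conv = nn.Conv2d(3, 8, 3)
            self.relu = nn.ReLU()
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.relu(self.conv(x))
            return self.dequant(x)

    model = TinyNet().train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)

    # QAT itself (fake quantization) can run on the GPU ...
    model.cuda()
    out = model(torch.randn(8, 3, 32, 32, device='cuda'))
    # ... a normal training loop with FakeQuantize active would go here.

    # ... but the real INT8 kernels only exist on CPU in 1.3, so convert and run there.
    model.cpu().eval()
    quantized = torch.quantization.convert(model)
    print(quantized(torch.randn(1, 3, 32, 32)))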

That should not be the case. fake_quantize is supported on the GPU. For more insight, can you compare the run-time per batch with and without quantization aware training on GPU?
Also, can you provide the quantization qconfig that you used for quantization aware training?
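
For reference, the stock QAT qconfig for the 'fbgemm' backend looks roughly like the sketch below (this is the default configuration, shown here only as an assumption about what you might be using -- the per-channel weight fake quant it contains is relevant to the slowdown discussed further down in the thread):

    import torch
    from torch.quantization import (FakeQuantize, MovingAverageMinMaxObserver,
                                    MovingAveragePerChannelMinMaxObserver, QConfig)

    # The prepackaged QAT config ...
    qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

    # ... which is roughly equivalent to building it by hand:
    # per-tensor fake quant for activations, per-channel fake quant for weights.
    manual_qconfig = QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True),
        weight=FakeQuantize.with_args(
            observer=MovingAveragePerChannelMinMaxObserver, quant_min=-128, quant_max=127,
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric, reduce_range=False,
            ch_axis=0))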

I believe the issue is due to the insertion of observers. I’m currently running MobileNetV2 with QAT, but only evaluating with calibration batches (no actual training), and then performing inference (after freezing the observers).

When using observer=torch.quantization.MinMaxObserver, the time per batch on CPU is 2.95s before running qat_model.apply(torch.quantization.disable_observer) and 1.51s afterwards. On a GPU, the numbers are 3.71s and 0.38s. So disabling the observers gives a clear speedup on the GPU, but with the observers enabled the GPU is actually slower than the CPU.
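
(For context, numbers like these can be reproduced with a simple timing loop along the lines of the sketch below; qat_model is the prepared QAT model, images stands in for a calibration batch, and the iteration count is just a placeholder.)

    import time
    import torch

    def time_per_batch(model, batch, n_iters=10, device='cpu'):
        # Average wall-clock time per forward pass, synchronizing on CUDA.
        model = model.to(device).eval()
        batch = batch.to(device)
        with torch.no_grad():
            model(batch)  # warm-up
            if device == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            for _ in range(n_iters):
                model(batch)
            if device == 'cuda':
                torch.cuda.synchronize()
        return (time.time() - start) / n_iters

    # With observers enabled (calibration) ...
    t_with_observers = time_per_batch(qat_model, images, device='cuda')
    # ... and after disabling them.
    qat_model.apply(torch.quantization.disable_observer)
    t_without_observers = time_per_batch(qat_model, images, device='cuda')
    print(t_with_observers, t_without_observers)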

When using the HistogramObserver, the performance on GPU (just for calibrating, not training) is 8x worse.

It appears that the issue is in PyTorch’s method for computing the per-channel quantization parameters – instead of parallelizing the computation, it runs a sequential loop over the channels of each layer, which bottlenecks performance. For PerChannelMinMaxObserver (and its moving-average variant) it should be pretty easy to modify the code to run in parallel.

If you want QAT training to be much faster, you can make the following changes:

First, parallelize calculate_qparams for the per-channel observers. Making the following change improved calibration performance (with observers enabled) by ~9x:

    def calculate_qparams(self):
        min_val, max_val = self.min_vals, self.max_vals
        if max_val is None or min_val is None:
            warnings.warn(
                "must run observer before calling calculate_qparams. "
                "Returning default scale and zero point"
            )
            return torch.tensor([1.0]), torch.tensor([0])

        assert torch.all(min_val <= max_val), "min {} should be less than max {}".format(
            min_val, max_val
        )

        if self.dtype == torch.qint8:
            if self.reduce_range:
                qmin, qmax = -64, 63
            else:
                qmin, qmax = -128, 127
        else:
            if self.reduce_range:
                qmin, qmax = 0, 127
            else:
                qmin, qmax = 0, 255
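        # min_val / max_val are per-channel tensors here, so everything below stays
        # vectorized over the channels instead of looping channel by channel in Python.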
        min_val = torch.clamp(min_val, max=0.0)
        max_val = torch.clamp(max_val, min=0.0)
        # The check for max_val == min_val is removed -- in that case I prefer taking scale = eps rather than 1.
        if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
            max_val = torch.max(-min_val, max_val)
            scale = max_val / ((qmax - qmin) / 2)
            scale = torch.clamp(scale, min=self.eps)
            zero_point = torch.ones_like(scale) * math.ceil((qmin + qmax) / 2)
        else:
            scale = (max_val - min_val) / float(qmax - qmin)
            scale = torch.clamp(scale, min=self.eps)
            zero_point = qmin - torch.round(min_val / scale)
            zero_point = torch.clamp(zero_point, min=qmin, max=qmax)

        zero_point = zero_point.long()

        # Note: this code keeps scale, zero_point on GPU. Only use this if you parallelize the 
        # implementation in fake quantize instead of using fake_quantize_per_channel_affine
        # by following the next code block. Otherwise CPU is faster
        # return scale.cpu(), zero_point.cpu()
        return scale, zero_point

Second, use a parallelized version of per-channel fake quantization (the C++ implementation of the operation iterates over every channel, which is slow). We can do this by changing FakeQuantize’s forward method to the following. Note that you get almost identical results to the previous code, but weight values that sit close to a rounding boundary when quantizing (e.g. 67.5000) may be quantized into the other bin. This change decreases inference time by a factor of 2 and calibration time by a factor of 1.5. After making these two changes, inference time with the quantized model is the same as inference time with the baseline model.

    def forward(self, X):
        if self.observer_enabled:
            self.observer(X.detach())
            self.scale, self.zero_point = self.calculate_qparams()
        if self.fake_quant_enabled:
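            # Per-channel case: fold scale / zero_point into X by hand and reuse the fast
            # per-tensor fake-quant kernel, instead of calling fake_quantize_per_channel_affine,
            # whose C++ implementation loops over the channels.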
            if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
                new_shape = [1] * len(X.shape)
                new_shape[self.ch_axis] = -1
                self.scale = self.scale.view(new_shape)
                self.zero_point = self.zero_point.view(new_shape)
                X = X / self.scale + self.zero_point
                X = torch.fake_quantize_per_tensor_affine(X, float(1.0),
                                                          int(0), self.quant_min,
                                                          self.quant_max)
                X = (X - self.zero_point) * self.scale
            else:
                X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                          int(self.zero_point), self.quant_min,
                                                          self.quant_max)
        return X

Good observations. Two issues have been created on GitHub to track the suggestions here:
https://github.com/pytorch/pytorch/issues/30348 -> Speed up calc q params in observers.
https://github.com/pytorch/pytorch/issues/30349 -> Speed up per channel fake-quant

For fake-quant, it would be better to speed up the C++ implementation by parallelizing it.


For fake-quant, the parallelization speedup depends on the tensor size, with slowdowns of 30-40% for smaller kernels. Since per-channel quantization is typically applied to weights, that slowdown for smaller tensors is undesirable.

A PR has been put up to speed up the calculation of qparams in observers: https://github.com/pytorch/pytorch/pull/30485