If you want quantization-aware training (QAT) to be much faster, you can make the following two changes:

First, parallelize `calculate_qparams` for the per-channel observers so that it computes the scale and zero point of every channel at once with tensor operations. Making the following change sped up calibration (with observers enabled) by about 9x.

```python
import math
import warnings

import torch


def calculate_qparams(self):
    min_val, max_val = self.min_vals, self.max_vals
    if max_val is None or min_val is None:
        warnings.warn(
            "must run observer before calling calculate_qparams. "
            "Returning default scale and zero point"
        )
        return torch.tensor([1.0]), torch.tensor([0])
    assert torch.all(min_val <= max_val), "min {} should be less than max {}".format(
        min_val, max_val
    )
    if self.dtype == torch.qint8:
        if self.reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:
        if self.reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    # Everything below operates on whole tensors of per-channel values,
    # so there is no Python loop over channels.
    # Extend each range to include 0 so that zero is exactly representable.
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    # The check of max_val == min_val is removed: in that case, I prefer taking scale = eps rather than 1.
    if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
        max_val = torch.max(-min_val, max_val)
        scale = max_val / ((qmax - qmin) / 2)
        scale = torch.clamp(scale, min=self.eps)
        zero_point = torch.ones_like(scale) * math.ceil((qmin + qmax) / 2)
    else:
        scale = (max_val - min_val) / float(qmax - qmin)
        scale = torch.clamp(scale, min=self.eps)
        zero_point = qmin - torch.round(min_val / scale)
    zero_point = torch.clamp(zero_point, min=qmin, max=qmax)
    zero_point = zero_point.long()
    # Note: this keeps scale and zero_point on the device of min_vals/max_vals (e.g. the GPU).
    # Only do this if you also replace fake_quantize_per_channel_affine with the
    # broadcasted implementation in the next code block; otherwise returning them
    # on the CPU is faster:
    # return scale.cpu(), zero_point.cpu()
    return scale, zero_point
```
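To make the difference concrete, here is a minimal sketch of the affine branch above written with whole-tensor ops, next to a per-channel Python loop that evaluates the same formula one channel at a time. Both function names are hypothetical, the loop is an illustrative baseline rather than the library's code, and the defaults (`qmin=0`, `qmax=255`, `eps` = float32 machine epsilon) are assumptions.

```python
import torch

EPS = torch.finfo(torch.float32).eps  # assumed eps value


def qparams_vectorized(min_vals, max_vals, qmin=0, qmax=255, eps=EPS):
    """Hypothetical helper: affine qparams for all channels in one pass."""
    min_vals = torch.clamp(min_vals, max=0.0)
    max_vals = torch.clamp(max_vals, min=0.0)
    scale = torch.clamp((max_vals - min_vals) / float(qmax - qmin), min=eps)
    zero_point = qmin - torch.round(min_vals / scale)
    zero_point = torch.clamp(zero_point, min=qmin, max=qmax).long()
    return scale, zero_point


def qparams_loop(min_vals, max_vals, qmin=0, qmax=255, eps=EPS):
    """Hypothetical baseline: the same formula, one channel per iteration."""
    n = min_vals.numel()
    scales = torch.empty(n, dtype=torch.float32)
    zero_points = torch.empty(n, dtype=torch.int64)
    for i in range(n):
        mn = torch.clamp(min_vals[i], max=0.0)
        mx = torch.clamp(max_vals[i], min=0.0)
        s = torch.clamp((mx - mn) / float(qmax - qmin), min=eps)
        zp = torch.clamp(qmin - torch.round(mn / s), min=qmin, max=qmax)
        scales[i] = s
        zero_points[i] = zp.long()
    return scales, zero_points


# Usage: both versions give identical qparams; only the amount of
# Python-level work per channel differs.
torch.manual_seed(0)
min_vals, max_vals = -torch.rand(64), torch.rand(64)
s_vec, zp_vec = qparams_vectorized(min_vals, max_vals)
s_loop, zp_loop = qparams_loop(min_vals, max_vals)
assert torch.equal(s_vec, s_loop) and torch.equal(zp_vec, zp_loop)
```

The loop launches several small ops per channel, so its cost grows with the channel count even on a GPU, while the vectorized version issues a fixed number of ops regardless of how many channels there are.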

Second, use a parallelized version of per-channel fake quantization, because the C++ implementation of `fake_quantize_per_channel_affine` iterates over every channel, which is slow. You can do this by changing `FakeQuantize`'s `forward` method to the following. The results are almost identical to the original implementation, but weights whose quantized value is close to a rounding boundary (e.g. 67.5000) may be rounded into the neighboring bin.

This change halves inference time and reduces calibration time by a factor of 1.5. With both changes in place, inference time with the fake-quantized (QAT) model is the same as with the baseline model.

```python
import torch


def forward(self, X):
    if self.observer_enabled:
        self.observer(X.detach())
        self.scale, self.zero_point = self.calculate_qparams()
    if self.fake_quant_enabled:
        if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
            # Reshape scale/zero_point so they broadcast along ch_axis,
            # e.g. [C, 1, 1, 1] for a conv weight with ch_axis=0.
            # Use local names so the stored 1-D buffers keep their shape.
            new_shape = [1] * X.dim()
            new_shape[self.ch_axis] = -1
            scale = self.scale.view(new_shape)
            zero_point = self.zero_point.view(new_shape)
            # Map X onto the integer grid, round and clamp it there with a single
            # per-tensor fake quant (scale 1, zero point 0), then map it back.
            # All channels are handled in one vectorized pass.
            X = X / scale + zero_point
            X = torch.fake_quantize_per_tensor_affine(X, 1.0, 0,
                                                      self.quant_min, self.quant_max)
            X = (X - zero_point) * scale
        else:
            X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                      int(self.zero_point), self.quant_min,
                                                      self.quant_max)
    return X
```
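The broadcasting trick in `forward` can be checked in isolation by comparing it with a reference that fake-quantizes one channel at a time. This is a sketch under stated assumptions: both helper names are hypothetical, channels are on axis 0, and the scales and zero points are random. The reference rounds `X / scale` and then adds the zero point, while the broadcasted version rounds after adding it. Because rounding is half-to-even, a value that lands exactly on a `.5` boundary with an odd zero point can end up one bin apart, which is one way the boundary differences mentioned above can arise.

```python
import torch


def fake_quant_broadcast(X, scale, zero_point, ch_axis, quant_min, quant_max):
    """Per-channel fake quant via broadcasting plus one per-tensor call."""
    shape = [1] * X.dim()
    shape[ch_axis] = -1
    scale = scale.view(shape)
    zero_point = zero_point.view(shape)
    X = X / scale + zero_point
    # With scale 1 and zero point 0 this is just clamp(round(X), quant_min, quant_max)
    X = torch.fake_quantize_per_tensor_affine(X, 1.0, 0, quant_min, quant_max)
    return (X - zero_point) * scale


def fake_quant_loop(X, scale, zero_point, ch_axis, quant_min, quant_max):
    """Hypothetical reference: quantize one channel at a time."""
    out = torch.empty_like(X)
    for c in range(X.size(ch_axis)):
        x_c = X.select(ch_axis, c)
        q = torch.round(x_c / scale[c]) + zero_point[c]
        q = torch.clamp(q, quant_min, quant_max)
        out.select(ch_axis, c).copy_((q - zero_point[c]) * scale[c])
    return out


# Usage: random conv-like weight with 8 output channels.
torch.manual_seed(0)
W = torch.randn(8, 4, 3, 3)
scale = torch.rand(8) * 0.02 + 0.005
zero_point = torch.randint(0, 256, (8,))
fast = fake_quant_broadcast(W, scale, zero_point, 0, 0, 255)
ref = fake_quant_loop(W, scale, zero_point, 0, 0, 255)
print((fast - ref).abs().max())  # 0, or at most one bin for a rare boundary value

# An exact tie: X / scale = 66.5 with an odd zero point of 1.
x = torch.tensor([[66.5]])
s1, zp1 = torch.tensor([1.0]), torch.tensor([1])
print(fake_quant_loop(x, s1, zp1, 0, 0, 255))       # round(66.5) + 1 = 67 -> dequantized 66.0
print(fake_quant_broadcast(x, s1, zp1, 0, 0, 255))  # round(67.5) = 68 -> dequantized 67.0
```

Apart from exact ties, floating-point error in `X / scale + zero_point` can also nudge a value that sits very close to a boundary into the neighboring bin, so a small number of one-bin differences is expected rather than a sign of a bug.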