Do I really need two separate model definitions for a quantized and an "unquantized" model?

Hi all,
I successfully quantized parts of a vision transformer model I found online:

    import torch

    backend = 'qnnpack'
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.backends.quantized.engine = backend

    # Dynamically quantize the nn.Linear layers of the float model to qint8.
    model = model.to(device='cpu')
    quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear},
                                                          dtype=torch.qint8)
    torch.save(quantized_model.state_dict(), ...)
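For completeness, the reload side looks roughly like this (my own sketch, not from the repo; `build_model` and the file name are placeholders): the saved state_dict only loads onto a model that has already been put through quantize_dynamic with the same qconfig_spec, because the quantized Linear modules store their weights as packed params with different keys.

    import torch

    float_model = build_model()     # placeholder: however the repo builds the float ViT
    float_model.eval()
    reloaded = torch.quantization.quantize_dynamic(
        float_model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
    )
    reloaded.load_state_dict(torch.load('quantized_vit.pth'))   # placeholder file name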

Note: I do not save my model in the TorchScript format because, as far as I can tell, that would require significantly rewriting the repo code (see the error below).

torch.jit.frontend.NotSupportedError: Compiled functions can’t take variable number of arguments or use keyword-only arguments with defaults:

One of the layers of the model contains the following line:

q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)

By default, x, self.q_proj and self.q_bias are torch.float32 variables. After quantization, x is still torch.float32, while self.q_proj and self.q_bias become torch.qint8 variables. If I try to run inference on the quantized model, this line fails with

TypeError: linear(): argument ‘weight’ must be Tensor, not method
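As far as I can tell, this is because quantize_dynamic replaces nn.Linear with a dynamically quantized Linear that exposes its weight as a method rather than an attribute. A minimal repro of that behaviour (my own sketch, not from the repo):

    import torch

    lin = torch.nn.Linear(8, 8)
    qlin = torch.quantization.quantize_dynamic(
        torch.nn.Sequential(lin), qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
    )[0]
    print(type(lin.weight))     # nn.Parameter (a float32 tensor)
    print(type(qlin.weight))    # bound method: the quantized Linear exposes weight() instead
    print(qlin.weight().dtype)  # torch.qint8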

If I modify it like so

q = F.linear(input=x, weight=self.q_proj.weight(), bias=self.q_bias)

it fails with

q = F.linear(input=x, weight=self.q_proj.weight(), bias=self.q_bias)
RuntimeError: self and mat2 must have the same dtype, but got Float and QInt8

And if I modify it further like so

q = F.linear(input=x, weight=torch.dequantize(self.q_proj.weight()), bias=torch.dequantize(self.q_bias))

then the code runs without errors.

Here’s the problem: as it stands, I need two model definitions: one for the quantized model, with the added changes, and one for the “unquantized” model (there are other modifications like this one wherever the behaviour of a float32 tensor differs from that of a qint8 tensor). Here are some scenarios I’m considering:

  1. Serialize the model (e.g. via TorchScript). This is not ideal, it requires significant changes to the original code, and I’m also not sure it would resolve the problem.
  2. Keep two separate model definitions. Very annoying to maintain.
  3. Configure the quantization script so that a single model definition works for both - this is what I’m looking for, but I don’t know if it’s possible.
  4. The original code does not use torch appropriately - I need to rewrite the failing parts so that the behaviour is identical no matter the dtype of the tensor (see the sketch after this list).
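For option 4, the kind of change I have in mind looks something like this (my own sketch; the linear_any_dtype helper is not in the repo):

    import torch
    import torch.nn.functional as F

    def linear_any_dtype(proj, bias, x):
        # Float model: proj.weight is an nn.Parameter (a float32 tensor).
        # Dynamically quantized model: proj.weight is a method returning a qint8 tensor.
        weight = proj.weight() if callable(proj.weight) else proj.weight
        if weight.is_quantized:
            weight = torch.dequantize(weight)
        if bias is not None and bias.is_quantized:
            bias = torch.dequantize(bias)
        return F.linear(input=x, weight=weight, bias=bias)

    # in the attention layer, the original line would become:
    # q = linear_any_dtype(self.q_proj, self.q_bias, x)

This keeps a single model definition, but it dequantizes the weight on every call, so that particular matmul runs in float again. I would only want to do this for the handful of layers where the quantized module’s own forward can’t be used directly.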

So here is the new tool: Quantization — PyTorch main documentation. Can you use this one instead?

Some example code for dynamic quantization can be found here: pytorch/test/quantization/pt2e/test_xnnpack_quantizer.py at main · pytorch/pytorch · GitHub
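Roughly, the dynamic-quantization flow in that test file looks like this (a sketch; the exact export/import paths vary between PyTorch releases, and the toy model and input are stand-ins for your ViT):

    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()  # stand-in for the ViT
    example_inputs = (torch.randn(1, 16),)                       # stand-in input

    # Quantizer configured for dynamic quantization of linear layers.
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_dynamic=True)
    )

    # On recent releases; older ones use torch._export.capture_pre_autograd_graph instead.
    exported = torch.export.export_for_training(model, example_inputs).module()

    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)   # calibration pass (not strictly needed for dynamic quant)
    quantized = convert_pt2e(prepared)

The result is a single exported graph with quantize/dequantize ops inserted, so you don’t have to touch the original module definitions at all, which may sidestep the two-definition problem.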