Convert back to Unquantized model

Hello. I have a question about convert in torch.quantization.

For a model like this,

  (module): LeNet(
    (l1): Linear(in_features=784, out_features=10, bias=True)
    (relu1): ReLU(inplace=True)
  )

After QAT and convert, I got

  (module): LeNet(
    (l1): QuantizedLinear(in_features=784, out_features=10, scale=0.5196203589439392, zero_point=78, qscheme=torch.per_channel_affine)
    (relu1): QuantizedReLU(inplace=True)
  )

But I’m looking for a way to run evaluation on CUDA, and for that I need to convert it back to the pre-QAT model, yet with ‘quantized FP32’ weights and perhaps a custom forward hook to perform activation quantization. Can someone advise on the best way to achieve this? In my understanding, these are the steps, but I’d like to make sure I’m not reinventing the wheel here:

  • Write a new converter that rebuilds the pre-QAT model architecture and loads the quantized weights (but in FP32).
  • Add a forward pre-hook that quantizes activations using the scale/zero_point from activation_post_process
    (should it be a forward pre-hook or a forward post-hook?); see the rough sketch after this list.
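
Roughly what I have in mind for the hook, as a minimal sketch (per-tensor quint8 activations, the helper name, and attaching it as a post-forward hook are all just my assumptions):

import torch

def activation_fakequant_hook(scale, zero_point, quant_min=0, quant_max=255):
    # quant_min/quant_max assume quint8 activations.
    def hook(module, inputs, output):
        # Round-trip the fp32 output through the affine quantization grid.
        return torch.fake_quantize_per_tensor_affine(
            output, scale, zero_point, quant_min, quant_max)
    return hook

# e.g., with qparams copied from a QAT module's activation_post_process:
# scale, zero_point = qat_module.activation_post_process.calculate_qparams()
# float_module.register_forward_hook(
#     activation_fakequant_hook(scale.item(), int(zero_point.item())))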

Any suggestions would be appreciated!

Hi @thyeros, you can use the QAT model after prepare and before convert to evaluate in fp32 while emulating int8. It will model the quantized numerics in fp32 with the FakeQuantize modules, and it works on CUDA. Here is an example from torchvision: https://github.com/pytorch/vision/blob/master/references/classification/train_quantization.py#L134
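
Roughly, as a minimal sketch of that workflow (eager mode; float_model and input_batch are placeholders):

import torch

model = float_model  # your fp32 model
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
model_prepared = torch.quantization.prepare_qat(model.train())

# ... QAT fine-tuning loop goes here ...

# Evaluate on CUDA: the FakeQuantize modules emulate int8 numerics in fp32.
model_prepared.apply(torch.quantization.disable_observer)
model_prepared.eval().cuda()
with torch.no_grad():
    out = model_prepared(input_batch.cuda())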

Oh, I see. Just using the QAT model in eval mode will run just fine on CUDA, with all the int8 emulation effects. So there is literally nothing special to do. Is that right?

In terms of eval speed on CUDA, would it still be a good idea to drop the observers, or would just disabling them be sufficient?

Thanks!

In terms of eval speed on CUDA, would it still be a good idea to drop the observers, or would just disabling them be sufficient?

You can use model.apply(torch.quantization.disable_observer) and model.apply(torch.quantization.enable_observer) to toggle them (example: https://github.com/pytorch/vision/blob/master/references/classification/train_quantization.py#L128).
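
For example, around an eval pass (model_prepared stands in for the prepared QAT model, evaluate for your own eval loop):

# Turn observers off so eval doesn't keep updating min/max statistics,
# then turn them back on before resuming QAT.
model_prepared.apply(torch.quantization.disable_observer)
evaluate(model_prepared)
model_prepared.apply(torch.quantization.enable_observer)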

Already disabling them. Great, thanks! :+1:

Hello, has anyone implemented this convert-back-to-an-unquantized-model function? I need to convert a QAT model into an SOC chip’s model format. Or should I not use PyTorch’s QAT for this purpose?

Can you clarify the use case here? The OP wanted to run a quantized model on CUDA, which is why vasiliy recommended using the pre-convert QAT model without observers, since that mimics quantized numerics but runs on CUDA.

If you can clarify what you are doing, we can maybe help, but in general we don’t have a way to unconvert the model, though you could write something that manually extracts whatever data you need.

Thank you for your reply. Here’s what I think the pipeline could be:

  1. Use QAT to train a quantization-friendly model.
  2. Convert this model to a standard floating-point model.
  3. Use the SOC’s quantization tool to convert the floating-point model into a specific model type that can run on the SOC.

The reason is that the SOC doesn’t support QAT and only accepts floating-point models as input. So, I think I might need to manually write something to extract data here, right?

Does it accept the post-QAT, pre-convert model? That still has floating-point weights until convert is run.
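
For example, a quick check (model_prepared stands in for the prepared QAT model):

import torch

# Before convert, the QAT-prepared modules still carry fp32 weights;
# the FakeQuantize modules only simulate int8 on the fly.
for name, module in model_prepared.named_modules():
    if isinstance(getattr(module, "weight", None), torch.Tensor):
        print(name, module.weight.dtype)  # expect torch.float32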

I am using the ESP32 quantization tool, which requires the model to be converted to ONNX format first. While converting the post-QAT, pre-convert model to ONNX, I encounter an ‘UnsupportedOperatorError’. Is this a common issue? Thank you for your patience and response.

import torch.onnx

model_prepared.apply(torch.ao.quantization.disable_fake_quant)
model_prepared.apply(torch.ao.quantization.disable_observer)
torch.onnx.export(model_prepared,              # model being run
                  input_fp32,                  # model input (or a tuple for multiple inputs)
                  "temp.onnx",                 # where to save the model (can be a file or file-like object)
                  export_params=True,          # store the trained parameter weights inside the model file
                  opset_version=10,            # the ONNX version to export the model to
                  do_constant_folding=True,    # whether to execute constant folding for optimization
                  input_names=['input'],       # the model's input names
                  output_names=['output'],     # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},     # variable-length axes
                                'output': {0: 'batch_size'}})

"UnsupportedOperatorError: Exporting the operator 'aten::fused_moving_avg_obs_fake_quant' to ONNX opset version 10 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues."

I used torch.fx to replace the torch.ao.nn.quantized.modules.conv.Conv2d and torch.ao.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d layers in the quantized model with torch.nn.Conv2d layers, and got a model that can be exported as an ONNX model.

I’m not entirely sure this is correct, but the results seem OK.

import copy

import torch
import torch.fx as fx

def unquant_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for node in new_graph.nodes:
        # Delete the first few layers and last few layers of the quant module
        # ("base_net_0_0_input_scale_0" is specific to my model).
        if node.target == "base_net_0_0_input_scale_0":
            node.next.next.next.replace_input_with(node.next.next.next.args[0], node.next.next.args[0])
            node.next.next.replace_all_uses_with(node.next.next)
            new_graph.erase_node(node.next.next)
            node.next.replace_all_uses_with(node.next)
            new_graph.erase_node(node.next)
            node.replace_all_uses_with(node)
            new_graph.erase_node(node)

        # Drop dequantize calls: pass their input straight through.
        if node.target == "dequantize":
            node.replace_all_uses_with(node.args[0])
            new_graph.erase_node(node)

        if node.target not in modules:
            continue

        # Replace quantized conv (and conv+relu) modules with normal convs.
        # unquant_conv and replace_node_module are my own helpers (not shown).
        if isinstance(modules[node.target],
                      (torch.ao.nn.quantized.modules.conv.Conv2d,
                       torch.ao.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d)):
            norm_conv = unquant_conv(modules[node.target])
            replace_node_module(node, modules, norm_conv)

    return fx.GraphModule(fx_model, new_graph)
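
And roughly how I call it before exporting (model_quantized is the converted quantized model, input_fp32 the same example input as above; the output file name is arbitrary):

# Usage sketch: strip the quant/dequant machinery, then export a plain float model.
float_like_model = unquant_model(model_quantized)
float_like_model.eval()
torch.onnx.export(float_like_model, input_fp32, "temp_float.onnx",
                  export_params=True, opset_version=10,
                  input_names=['input'], output_names=['output'])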

Yeah, that makes sense; you’re losing the learned scales and zero points from quantization, though.

I have noticed that there are two sets of scales and zero_points in torch.ao.nn.quantized’s Conv2d.

  1. One set is in QuantizedConv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.09829162061214447, zero_point=48, padding=(1, 1)). I believe this scale and zero_point are for the output of this layer, so if I use this model as a float model, this scale becomes irrelevant.
  2. The other scale is found in conv.weight(), showing scale=tensor([0.0489, 0.0379, 0.0250, 0.0418, 0.0370, 0.0329, 0.0264, 0.0320, 0.0184, 0.0291, 0.0283, 0.0276, 0.0277, 0.0296, 0.0229, 0.0343], dtype=torch.float64), which applies to each channel’s weights. I think these scales have already been applied when converting the QAT model to a quantized model.

Therefore, I haven’t lost any scales or zero points, am I correct?

The output scale and zero point are learned during QAT based on the observed activation ranges; that’s what you’d be losing. You can use the model as a float model, but at some point, if you want to quantize the model again, you’d want those. The weight qparams can be determined from the weight itself, so their loss is irrelevant.
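
If you do end up needing them for the SOC tool, you can read them off the converted modules before stripping the quantized layers, something like this (a sketch; model_quantized is a placeholder):

# Collect the learned output qparams so they can be handed to the
# downstream quantization tool separately.
activation_qparams = {}
for name, module in model_quantized.named_modules():
    if hasattr(module, "scale") and hasattr(module, "zero_point"):
        activation_qparams[name] = (float(module.scale), int(module.zero_point))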

Thank you once again. Your explanation has greatly enhanced my understanding.