quantize_fx produces graphs that are not exportable to Caffe2 (or ONNX)

Hi all. I've been trying to wrap my head around how to export quantized models, and judging by other posts on this forum, it seems that only Caffe2 is supported at this point. However, I've noticed that torch.fx sometimes produces graphs for very simple models that can't even be exported to Caffe2. Here's a somewhat long example; I'll highlight one case below:

import torch
from torch import nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import symbolic_trace
import io
import onnx

# Parameters
input_shape = (1,3,224,224)
    
# More parameters
WITH_ADD = False # buggy
#WITH_ADD = True
MAKE_EXPORTABLE = True
#MAKE_EXPORTABLE = False # buggy
#EXPORT_TO_ONNX = True # buggy
EXPORT_TO_ONNX = False

class MyModuleAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)
        
    def forward(self, x):
        return self.linear((x + x) + x).clamp(min=0.0, max=1.0)

# Simple module which is just a linear layer followed by a clamped activation
class MyModuleNoAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x).clamp(min=0.0, max=1.0)

class PassThrough(nn.Module):
    def forward(self, *args):
        return args[0]

# This FX transform removes the torch.quantize_per_tensor call from forward()
# and converts tensor-valued scale/zero_point attributes to Python scalars,
# to improve exportability. Call torch.quantize_per_tensor on the input outside the model.
def post_quantization_transform(m):
    for field_name in dir(m):
        if ('scale' in field_name or 'zero_point' in field_name) and isinstance(getattr(m, field_name), torch.Tensor):
            tensor_value = getattr(m, field_name)
            delattr(m, field_name)
            setattr(m, field_name, tensor_value.item())
    
    symbolic_traced = symbolic_trace(m)

    count = 0
    for node in symbolic_traced.graph.nodes:
        if node.target == torch.quantize_per_tensor:
            count += 1
            with symbolic_traced.graph.inserting_after(node):
                setattr(symbolic_traced, 'passthrough_{}'.format(count), PassThrough())
                new_node = symbolic_traced.graph.call_module('passthrough_{}'.format(count), args=(node.args[0],))
            node.replace_all_uses_with(new_node)
            symbolic_traced.graph.erase_node(node)  # public API instead of the private _remove_from_list()
            setattr(symbolic_traced, node.name, None)

    symbolic_traced.recompile()
    return symbolic_traced

module = MyModuleAdd() if WITH_ADD else MyModuleNoAdd()

# Do a standard fx quantization for this experiment; nothing fancy
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
module.eval()
model_prepared = quantize_fx.prepare_fx(module, qconfig_dict)
model_quantized = quantize_fx.convert_fx(model_prepared)
x = torch.randn(3,4)

if MAKE_EXPORTABLE:
    model_final = post_quantization_transform(model_quantized)
    x = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.quint8)
else:
    model_final = model_quantized

outputs = model_final(x)

# Export to TorchScript
traced_model = torch.jit.trace(model_final, x)
buf = io.BytesIO()
torch.jit.save(traced_model, buf)
buf.seek(0)
reloaded_model = torch.jit.load(buf)

if EXPORT_TO_ONNX:
    try:
        ft = io.BytesIO()
        torch.onnx.export(reloaded_model, x, ft, example_outputs=outputs, opset_version=13)
        ft.seek(0)
        print('Succeeded exporting TorchScript model with opset version 13')
    except Exception as e:
        print('Failed to export TorchScript model to ONNX:', e)
    try:
        ft = io.BytesIO()
        torch.onnx.export(model_final, x, ft, opset_version=13)
        ft.seek(0)
        print('Succeeded exporting model with opset version 13')
    except Exception as e:
        print('Failed to export model to ONNX:', e)

# Export to caffe2
f = io.BytesIO()
torch.onnx.export(reloaded_model, x, f, example_outputs=outputs,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,)
f.seek(0)
onnx_model = onnx.load(f)
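A side note on the input quantization in the script: torch.quantize_per_tensor applies affine quantization, q = clamp(round(x / scale) + zero_point, qmin, qmax), where the range for quint8 is [0, 255]. The stdlib sketch below mirrors that arithmetic to show what scale=1, zero_point=0 does to a torch.randn input: every negative value becomes 0 and the rest are rounded to integers. The function names are hypothetical (not PyTorch APIs), and rounding via Python's round() (round half to even) is an assumption about matching PyTorch's behavior.

```python
# Hypothetical pure-Python mirror of per-tensor affine quantization for quint8.
# Not a PyTorch API; rounding uses Python's round() (round half to even),
# assumed to match PyTorch closely enough for illustration.
QUINT8_MIN, QUINT8_MAX = 0, 255


def quantize_per_tensor(values, scale, zero_point, qmin=QUINT8_MIN, qmax=QUINT8_MAX):
    """q = clamp(round(x / scale) + zero_point, qmin, qmax) for each x."""
    return [min(max(round(v / scale) + zero_point, qmin), qmax) for v in values]


def dequantize(qvalues, scale, zero_point):
    """x_hat = (q - zero_point) * scale for each q."""
    return [(q - zero_point) * scale for q in qvalues]


# The script quantizes its input with scale=1, zero_point=0, dtype=quint8:
x = [-1.3, 0.4, 0.6, 2.5]
q = quantize_per_tensor(x, scale=1.0, zero_point=0)
print(q)                      # [0, 0, 1, 2]
print(dequantize(q, 1.0, 0))  # [0.0, 0.0, 1.0, 2.0]
```

This shouldn't change whether the graph exports, but it matters when comparing outputs before and after quantization; deriving scale and zero_point from the input's actual range (for example with one of torch.quantization's observers) would avoid clamping the negatives.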

The script can trigger other issues as well, but I've left the parameters at values that reproduce one specific bug: no adds (WITH_ADD = False), the model made exportable with post_quantization_transform (MAKE_EXPORTABLE = True), and no ONNX export (EXPORT_TO_ONNX = False). Running the script produces the following error:

RuntimeError: false INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp":94, please report a bug to PyTorch. Unrecognized quantized operator while trying to compute q_scale for operator prim::Param

Strangely, if I use the model with the adds (WITH_ADD = True), the export to Caffe2 succeeds. Any idea why that is? Thanks
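The script switches configurations by commenting flag lines in and out. A minimal sketch of sweeping all eight flag combinations in one loop, assuming the script body is wrapped in a hypothetical run_experiment(with_add, make_exportable, export_to_onnx) function that raises on failure, could look like this. The fake_experiment at the bottom is only a stand-in to show the shape of the output.

```python
import itertools


def sweep(run_experiment):
    """Run every combination of (WITH_ADD, MAKE_EXPORTABLE, EXPORT_TO_ONNX).

    run_experiment is a hypothetical wrapper around the script body; any
    exception it raises is recorded as that configuration's outcome.
    """
    results = []
    for flags in itertools.product([False, True], repeat=3):
        try:
            run_experiment(*flags)
            outcome = "No error"
        except Exception as exc:  # record the failure and keep sweeping
            outcome = f"{type(exc).__name__}: {exc}"
        results.append((*flags, outcome))
    return results


def print_table(results):
    print("WITH_ADD | MAKE_EXPORTABLE | EXPORT_TO_ONNX | Outcome")
    for row in results:
        print(" | ".join(str(cell) for cell in row))


# Stand-in for the real experiment, for illustration only.
def fake_experiment(with_add, make_exportable, export_to_onnx):
    if not make_exportable:
        raise RuntimeError("placeholder failure")


print_table(sweep(fake_experiment))
```

Because each configuration's exception is caught and recorded, one failing export doesn't stop the remaining combinations from running.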

For anyone interested, I've also summarized the other issues the script can produce in this table:

| Model type | Add type | Exportability | Export backend | Error | Notes |
|---|---|---|---|---|---|
| TorchScript | No add | Made exportable | ONNX opset 13 | Exporting the operator linear to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub. | |
| TorchScript | With add | Made exportable | ONNX opset 13 | add() takes from 3 to 4 positional arguments but 5 were given | This is a quantized add. Inspecting the generated code shows that 4 arguments are passed: add_1 = ops.quantized.add(passthrough_1, passthrough_1, 1., 0)… The arguments are (tensor_a, tensor_b, scale, zero_point). |
| TorchScript | No add; with add | Not made exportable | ONNX opset 13 | Exporting the operator quantize_per_tensor to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub. | |
| PyTorch | No add; with add | Made exportable; not made exportable | ONNX opset 13 | 'torch.dtype' object has no attribute 'detach' | Maybe the exporter thinks some dtype member variable is a parameter and is trying to detach it. However, I found no such dtype member variable in the model. |
| TorchScript | No add | Made exportable | Caffe2 | RuntimeError: false INTERNAL ASSERT FAILED at /pytorch/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp:94, please report a bug to PyTorch. Unrecognized quantized operator while trying to compute q_scale for operator prim::Param | |
| TorchScript | With add; no add | Not made exportable | Caffe2 | RuntimeError: Expected node type 'onnx::Constant' for argument 'zero_point' of node 'quantize_per_tensor', got 'prim::Param'. | |
| TorchScript | With add | Made exportable | Caffe2 | No error | Only case without an error |
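On the "add() takes from 3 to 4 positional arguments but 5 were given" row: one plausible reading, not verified against the exporter source, is that ONNX symbolic functions receive the graph g as their first argument, so the four arguments of quantized::add become five, and that the quantized add was routed to the regular add symbolic, whose signature is assumed here to be add(g, self, other, alpha=None). The stand-in below reproduces the same Python TypeError from that arity mismatch alone; call_symbolic is a hypothetical helper, not part of torch.onnx.

```python
# Stand-in for the regular (non-quantized) ONNX add symbolic. The signature is an
# assumption: graph first, then the two operands and an optional alpha.
def add(g, self, other, alpha=None):
    return ("Add", self, other)


def call_symbolic(fn, g, *op_inputs):
    """Hypothetical helper: call a symbolic the way an exporter would (graph first)."""
    try:
        return fn(g, *op_inputs)
    except TypeError as err:
        return f"TypeError: {err}"


# Regular add: graph plus two operands -> three positional arguments, accepted.
print(call_symbolic(add, "graph", "a", "b"))  # ('Add', 'a', 'b')

# quantized::add carries four inputs (qa, qb, scale, zero_point); with the graph that is five.
print(call_symbolic(add, "graph", "qa", "qb", 1.0, 0))
# TypeError: add() takes from 3 to 4 positional arguments but 5 were given
```

If this reading is right, it could also fit the Caffe2 case succeeding with adds, provided the quantized symbolics registered for that path define an add that accepts scale and zero_point; that part is a guess.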