Hi all. I’ve been trying to wrap my head around how to export quantized models. Judging by other posts on this forum, it seems that only Caffe2 is supported at this point. However, I’ve noticed that torch.fx sometimes produces graphs for very simple models that can’t even be exported to Caffe2. Here’s a somewhat long example; I’ll highlight one case:
import torch
from torch import nn
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import symbolic_trace
import io
import onnx
# Parameters
# NOTE(review): input_shape appears unused in this script; x below is built as (3, 4).
input_shape = (1, 3, 224, 224)

# Experiment switches. Settings marked as reproducing a bug:
#   WITH_ADD = False         (alternative: True)
#   MAKE_EXPORTABLE = False  (alternative: True)
#   EXPORT_TO_ONNX = True    (alternative: False)
WITH_ADD = False
MAKE_EXPORTABLE = True
EXPORT_TO_ONNX = False
class MyModuleAdd(torch.nn.Module):
    """Linear(4 -> 5) applied to ``(x + x) + x``, with the output clamped to [0, 1].

    The two explicit additions are kept as separate ops so they show up as
    ``add`` nodes when the module is traced/quantized.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        tripled = (x + x) + x
        projected = self.linear(tripled)
        return projected.clamp(min=0.0, max=1.0)
class MyModuleNoAdd(torch.nn.Module):
    """Minimal model: a Linear(4 -> 5) layer whose output is clamped to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        projected = self.linear(x)
        return projected.clamp(min=0.0, max=1.0)
class PassThrough(nn.Module):
    """Identity module: returns its first positional argument unchanged.

    Any additional positional arguments are accepted and ignored.
    """

    def forward(self, *args):
        first_input = args[0]
        return first_input
def _scalarize_qparams(m):
    """Replace one-element tensor attributes named like qparams with Python scalars.

    Mutates ``m`` in place: every attribute whose name contains ``'scale'`` or
    ``'zero_point'`` and holds a single-element tensor is replaced by its
    ``.item()`` value. Multi-element tensors (e.g. per-channel parameters) are
    left untouched, since ``.item()`` would raise on them.
    """
    for field_name in dir(m):
        if 'scale' not in field_name and 'zero_point' not in field_name:
            continue
        value = getattr(m, field_name)
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            # Delete first: nn.Module refuses to assign a non-tensor to a name
            # that is still registered as a buffer.
            delattr(m, field_name)
            setattr(m, field_name, value.item())


def _strip_quantize_per_tensor(gm):
    """Bypass every ``torch.quantize_per_tensor`` call in ``gm``'s graph, in place.

    Each such node is replaced by a call to a fresh ``PassThrough`` submodule
    (registered as ``passthrough_<k>``, k starting at 1) that forwards the
    node's first argument; the original node is then erased and ``gm`` is
    recompiled.
    """
    graph = gm.graph
    count = 0
    # Iterate over a snapshot: nodes are erased inside the loop.
    for node in list(graph.nodes):
        if node.target != torch.quantize_per_tensor:
            continue
        count += 1
        submodule_name = 'passthrough_{}'.format(count)
        setattr(gm, submodule_name, PassThrough())
        with graph.inserting_after(node):
            new_node = graph.call_module(submodule_name, args=(node.args[0],))
        node.replace_all_uses_with(new_node)
        # erase_node (unlike the private _remove_from_list) marks the node as
        # erased, updates the graph length and drops it from its inputs' users.
        graph.erase_node(node)
    graph.lint()
    gm.recompile()


def post_quantization_transform(m):
    """Remove input quantization from a quantized module to aid export.

    The scale/zero_point attributes of ``m`` are converted to Python scalars,
    ``m`` is re-traced with ``torch.fx.symbolic_trace``, and every
    ``torch.quantize_per_tensor`` call in the traced graph is replaced by an
    identity ``PassThrough`` submodule. After this transform the caller must
    apply ``torch.quantize_per_tensor`` to the input outside the model.

    Args:
        m: quantized module (presumably the output of ``convert_fx``).
            NOTE: mutated in place by the scalar conversion.

    Returns:
        The new, recompiled ``GraphModule``.
    """
    _scalarize_qparams(m)
    symbolic_traced = symbolic_trace(m)
    _strip_quantize_per_tensor(symbolic_traced)
    return symbolic_traced
module = MyModuleAdd() if WITH_ADD else MyModuleNoAdd()

# Do a standard FX quantization for this experiment; nothing fancy.
# NOTE(review): no data is run through model_prepared before convert_fx, so the
# activation observers presumably fall back to default qparams (scale=1,
# zero_point=0), which the input quantization below appears to rely on --
# confirm. The 'qnnpack' qconfig is used without setting
# torch.backends.quantized.engine here; confirm the intended backend.
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
module.eval()
model_prepared = quantize_fx.prepare_fx(module, qconfig_dict)
model_quantized = quantize_fx.convert_fx(model_prepared)

x = torch.randn(3, 4)
if MAKE_EXPORTABLE:
    # post_quantization_transform strips torch.quantize_per_tensor from the
    # graph, so the input has to be quantized here, outside the model.
    model_final = post_quantization_transform(model_quantized)
    x = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.quint8)
else:
    model_final = model_quantized
outputs = model_final(x)

# Export to TorchScript, then round-trip it through an in-memory buffer.
traced_model = torch.jit.trace(model_final, x)
buf = io.BytesIO()
torch.jit.save(traced_model, buf)
buf.seek(0)
reloaded_model = torch.jit.load(buf)

if EXPORT_TO_ONNX:
    # Best-effort diagnostics: failures are printed rather than raised so that
    # both export paths are always attempted.
    try:
        ft = io.BytesIO()
        torch.onnx.export(reloaded_model, x, ft, example_outputs=outputs, opset_version=13)
        ft.seek(0)
        # This branch exports the reloaded TorchScript module (not the plain
        # FX module exported below), and the message says so.
        print('Succeeded exporting torch script model with opset version 13')
    except Exception as e:
        print('Failed to export TorchScript model to onnx,', e)
    try:
        ft = io.BytesIO()
        torch.onnx.export(model_final, x, ft, opset_version=13)
        ft.seek(0)
        print('Succeeded exporting model with opset version 13')
    except Exception as e:
        print('Failed to export model to onnx,', e)

# Export to caffe2: ONNX export of the reloaded TorchScript module with the
# ONNX_ATEN_FALLBACK operator export type.
f = io.BytesIO()
torch.onnx.export(reloaded_model, x, f, example_outputs=outputs,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,)
f.seek(0)
onnx_model = onnx.load(f)
This script can trigger other issues as well, but I’ve left the parameters at the values that reproduce one specific bug: the model contains no adds (WITH_ADD = False), the exportability transform is applied (MAKE_EXPORTABLE = True), and the ONNX export is skipped (EXPORT_TO_ONNX = False). Running the script produces the following error during the final (Caffe2) export:
RuntimeError: false INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp":94, please report a bug to PyTorch. Unrecognized quantized operator while trying to compute q_scale for operator prim::Param
Strangely, if I use the model with the adds (WITH_ADD = True), the export to Caffe2 succeeds. Any idea why that is? Thanks!