Hello.
I’m trying to transform the graph module of an exported program after loading it from a pt2 file.
I wrote a transformer class according to this, but after the transformation there seems to be no way to update the exported program. I also tried to create a new exported program from the transformed graph module, but I couldn’t call the export method because the loaded exported program doesn’t seem to have example inputs.
import torch


class DecomposeFakeQuant(torch.fx.Transformer):
    """
    Original:
        def f(x):
            return torch.fake_quantize_per_tensor_affine(x, s, zp, q_min, q_max)
    After pass:
        def f(x):
            x = qd.quantize_per_tensor(x, s, zp, q_min, q_max, dtype)
            x = qd.dequantize_per_tensor(x, s, zp, q_min, q_max, dtype)
            return x
    """

    def call_function(self, target, args, kwargs):
        # Leave every op other than the fake-quant op untouched.
        if target != torch.ops.aten.fake_quantize_per_tensor_affine_cachemask.default:
            return super().call_function(target, args, kwargs)
        # Emit quantize, then dequantize on its result with the same remaining args.
        q_x = super().call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default, args, kwargs)
        args = (q_x, *args[1:])
        return super().call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default, args, kwargs)


exported_program = torch.export.load(input_pt2)
transformed_gm = DecomposeFakeQuant(exported_program.graph_module).transform()
# exported_program._graph_module = transformed_gm -> ERROR because graph_module is immutable
# transformed_ep = export(transformed_gm, exported_program.example_inputs) -> ERROR because exported_program doesn't have example_inputs
Is there a proper way to handle this case?