Transform graph module of the exported program

Hello.

I’m trying to transform the graph module of an exported program after loading it from a .pt2 file.

I’ve written the transformer class below, following this. However, after the transformation there seems to be no way to update the exported program. I also tried to create a new exported program from the transformed graph module, but I couldn’t call export because the loaded exported program doesn’t seem to have example inputs.

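```python
import operator

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  registers torch.ops.quantized_decomposed


class DecomposeFakeQuant(torch.fx.Transformer):
    """
    Original:
        def f(x):
            return torch.fake_quantize_per_tensor_affine(x, s, zp, q_min, q_max)

    After pass:
        def f(x):
            x = qd.quantize_per_tensor(x, s, zp, q_min, q_max, dtype)
            x = qd.dequantize_per_tensor(x, s, zp, q_min, q_max, dtype)
            return x
    """
    def call_function(self, target, args, kwargs):
        # In the exported graph the op is fake_quantize_per_tensor_affine_cachemask, which
        # returns (output, mask) and is read through getitem nodes. Index the Python tuple
        # returned below directly instead of tracing a getitem on the dequantized tensor.
        if target is operator.getitem and isinstance(args[0], tuple):
            return args[0][args[1]]
        if target != torch.ops.aten.fake_quantize_per_tensor_affine_cachemask.default:
            return super().call_function(target, args, kwargs)

        x, s, zp, q_min, q_max = args  # assumes all five arguments are positional
        dtype = torch.int8 if q_min < 0 else torch.uint8  # assumes an 8-bit range
        qparams = (s, zp, q_min, q_max, dtype)  # quantize/dequantize also require dtype
        q_x = super().call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default, (x, *qparams), {})
        dq_x = super().call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default, (q_x, *qparams), {})
        return (dq_x, None)  # the mask output is dropped


exported_program = torch.export.load(input_pt2)
transformed_gm = DecomposeFakeQuant(exported_program.graph_module).transform()
# exported_program._graph_module = transformed_gm -> ERROR because graph_module is immutable
# transformed_ep = torch.export.export(transformed_gm, exported_program.example_inputs) -> ERROR because exported_program doesn't have example_inputs
```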

Is there a proper way to update an exported program with a transformed graph module?