Saving a modified FX graph

I quantized ResNet-18 with PT2E. In the quantized model, I replaced a few patterns with custom functions using `subgraph_rewriter.replace_pattern`, and registered the custom functions with `torch.fx.wrap`.
How can I save the modified graph so that it keeps the custom functions?
If I export the model with `torch.export` and save it, the custom functions are replaced with standard operators, so I can't continue working on the modified graph at a later point.

The code snippet is below. `torchfxlib` contains the custom functions, which are mainly used to fuse the q/dq (quantize/dequantize) nodes between operators into the same node.

# Load the pretrained torchvision ResNet-18 weights and copy them into the
# project's own resnet18 definition (``resnet`` is imported elsewhere —
# presumably a local copy of the torchvision architecture; verify).
model_to_quantize_wt = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).to("cpu")
model_to_quantize = resnet.resnet18()
model_to_quantize.load_state_dict(model_to_quantize_wt.state_dict())
model_to_quantize.eval()

# Capture the model as an FX GraphModule for PT2E quantization.
# ``example_inputs`` must be a tuple of tensors defined before this point,
# e.g. (torch.randn(1, 3, 224, 224),) — TODO confirm the input size used.
exported_model = torch.export.export_for_training(
    model_to_quantize, example_inputs
).module()

from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# NOTE(review): recent PyTorch releases deprecate torch.ao's PT2E quantizers
# (moved to torchao / executorch) — check the deprecation warning for your
# version before upgrading.

# Quantization config. The original called get_quantization_config(), which
# is neither imported nor defined in this script; the imported symmetric
# config is used instead. Called with its defaults (per-tensor, static PTQ);
# pass is_per_channel=True for per-channel weight quantization.
qconfig = get_symmetric_quantization_config()

# Select a quantizer and configure it.
quantizer = XNNPACKQuantizer()
quantizer.set_global(qconfig)

# Prepare for quantization: insert observers into the exported graph.
prepared_model = prepare_pt2e(exported_model, quantizer)

def calibrate(model, data_loader):
torch.ao.quantization.move_exported_model_to_eval(model)
with torch.no_grad():
for image, target in data_loader:
model(image)

from torch.ao.quantization.quantize_pt2e import convert_pt2e

# Run calibration on sample data so the observers collect ranges.
calibrate(prepared_model, data_loader_cal)

# Convert the observed model into a quantized graph with quantize/dequantize
# ops. fold_quantize=True constant-folds the quantize op applied to the
# weights, so the graph stores already-quantized weight tensors.
quantized_model = convert_pt2e(prepared_model, fold_quantize=True)

import warnings

# Transform the graph with custom functions from torchfxlib.
# Each entry is (label, pattern, replacement). replace_pattern mutates
# quantized_model in place, so order matters: this is the order the rewrites
# were originally applied in (presumably the fused Conv2d+ReLU pattern must
# claim its nodes before the plain Conv2d pattern, and the standalone q/dq
# patterns must run last — verify against torchfxlib).
QUANT_REWRITES = (
    ("qConv2dRelu", torchfxlib.pattern_qConv2dRelu, torchfxlib.replacement_qConv2dRelu),
    ("qConv2d", torchfxlib.pattern_qConv2d, torchfxlib.replacement_qConv2d),
    ("qmax_pool2d", torchfxlib.pattern_qmax_pool2d, torchfxlib.replacement_qmax_pool2d),
    ("qadd", torchfxlib.pattern_qadd, torchfxlib.replacement_qadd),
    ("qavg_pool2d", torchfxlib.pattern_qavg_pool2d, torchfxlib.replacement_qavg_pool2d),
    ("qflatten", torchfxlib.pattern_qflatten, torchfxlib.replacement_qflatten),
    ("qlinear", torchfxlib.pattern_qlinear, torchfxlib.replacement_qlinear),
    ("q", torchfxlib.pattern_q, torchfxlib.replacement_q),
    ("dq", torchfxlib.pattern_dq, torchfxlib.replacement_dq),
)

# Keep every rewrite's match list instead of overwriting it, so the number of
# replaced occurrences per pattern can be inspected afterwards.
# matched_pattern still ends up holding the last ("dq") result, as before.
rewrite_matches = {}
for label, pattern, replacement in QUANT_REWRITES:
    matched_pattern = subgraph_rewriter.replace_pattern(
        quantized_model, pattern, replacement
    )
    rewrite_matches[label] = matched_pattern
    if not matched_pattern:
        # A pattern that never matches usually means it no longer lines up
        # with the graph produced by convert_pt2e.
        warnings.warn(f"subgraph rewrite {label!r} matched nothing in quantized_model")

import os

# Export the transformed graph and save the ExportedProgram.
#
# NOTE(review): torch.export re-traces quantized_model, and torch.fx.wrap
# only affects FX symbolic tracing. As a result, the torchfxlib functions
# are traced through and saved as standard aten ops, which is why they are
# missing after loading. To keep them as single nodes in a saved
# ExportedProgram, register them as custom operators
# (torch.library.custom_op) so export treats them as opaque. The module that
# registers them must then be imported before torch.export.load.
# Alternatively, save the pre-rewrite quantized program and re-apply the
# subgraph rewrites after loading.
#
# torch.export.save's documented extension is ".pt2"; the ".pth" name is
# kept so existing loaders keep working.
pt2e_quantized_model_file_path = os.path.join(saved_model_dir, "resnet18_quantized.pth")
quantized_ep = torch.export.export(quantized_model, example_inputs)
torch.export.save(quantized_ep, pt2e_quantized_model_file_path)