Hello everyone! I’m currently using TorchDynamo to generate the FX graph of a bert-base-uncased model. Specifically, I write a custom compile backend which only print the FX graph. I also employ the decomposition rule from inductor in order to decompose large operators to aten/prims operators. Here is the code I use
```python
from transformers import BertTokenizer, BertModel
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp
from torch._functorch.aot_autograd import aot_module_simplified

def custom_backend(gm, inputs):
    def _compiler(_gm, _):
        _gm.graph.print_tabular()
        return _gm.forward
    return aot_module_simplified(gm, inputs, fw_compiler=_compiler,
                                 decompositions=inductor_decomp)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "Hello world!"
encoded_text = tokenizer(text, return_tensors="pt")

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

model_compiled = dynamo.optimize(backend=custom_backend)(model)
model_compiled(**encoded_text)
```
However, the FX graph produced by this code is not quite what I expected: the input list of the `output` node is very long, and most of those input values are unnecessary for my use case. Here is the `output` node as printed:
So my questions are: what causes such a long input list on the `output` node? And is there any way to influence TorchDynamo's FX graph capture so that the `output` node contains only the values I actually need? Any insight would be greatly appreciated!
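In case it clarifies what I'm after, here is a minimal standalone sketch (plain `torch.fx` on a toy module, no Dynamo or BERT involved) of the kind of post-processing I have in mind: rewriting the `output` node's args to keep only the values I want, then dead-code-eliminating the rest. I'm not sure this is the right approach for a Dynamo-captured graph, which is part of my question.

```python
import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = x * 2  # an extra output I don't actually need
        return a, b

gm = fx.symbolic_trace(Toy())

# Rewrite the output node so it returns only the first value.
for node in gm.graph.nodes:
    if node.op == "output":
        kept = node.args[0][0]   # node.args[0] is the tuple of returned values
        node.args = ((kept,),)

# Remove the now-unused `x * 2` node and regenerate the code.
gm.graph.eliminate_dead_code()
gm.recompile()

x = torch.ones(3)
print(gm(x))  # a 1-tuple containing only x + 1
```

Doing the same rewrite inside `_compiler` above feels fragile, since I don't know which of the many `output` inputs AOTAutograd itself relies on.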