Hello everyone! I’m currently using TorchDynamo to generate the FX graph of a bert-base-uncased model. Specifically, I write a custom compile backend which only print the FX graph. I also employ the decomposition rule from inductor in order to decompose large operators to aten/prims operators. Here is the code I use
```python
from transformers import BertTokenizer, BertModel
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp
from torch._functorch.aot_autograd import aot_module_simplified

def custom_backend(gm, inputs):
    def _compiler(_gm, _):
        _gm.graph.print_tabular()
        return _gm.forward
    return aot_module_simplified(gm, inputs, fw_compiler=_compiler,
                                 decompositions=inductor_decomp)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "Hello world!"
encoded_text = tokenizer(text, return_tensors="pt")

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

model_compiled = dynamo.optimize(backend=custom_backend)(model)
model_compiled(**encoded_text)
```
However, the FX graph produced by this code is not quite what I expected: the input list of the `output` node is very long, and most of those input values are unnecessary for my use case. Here is the `output` node as printed:
So my questions are: what causes such a long input list on the `output` node? And is there any way to influence TorchDynamo's FX graph capture so that the `output` node contains only the values I actually need? Any insight would be greatly appreciated!
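In case it clarifies what I'm after, here is a minimal standalone sketch (plain `torch.fx` on a toy module, no Dynamo or BERT involved) of the kind of post-processing I have in mind: rewriting the `output` node's args to keep only the values I want, then dead-code-eliminating the rest. I'm not sure this is the right approach for a Dynamo-captured graph, which is part of my question.

```python
import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = x * 2  # an extra output I don't actually need
        return a, b

gm = fx.symbolic_trace(Toy())

# Rewrite the output node so it returns only the first value.
for node in gm.graph.nodes:
    if node.op == "output":
        kept = node.args[0][0]   # node.args[0] is the tuple of returned values
        node.args = ((kept,),)

# Remove the now-unused `x * 2` node and regenerate the code.
gm.graph.eliminate_dead_code()
gm.recompile()

x = torch.ones(3)
print(gm(x))  # a 1-tuple containing only x + 1
```

Doing the same rewrite inside `_compiler` above feels fragile, since I don't know which of the many `output` inputs AOTAutograd itself relies on.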