How to manipulate the inputs of the `output` node of an FX graph generated by TorchDynamo?

Hello everyone! I’m currently using TorchDynamo to generate the FX graph of a bert-base-uncased model. Specifically, I wrote a custom compile backend that only prints the FX graph. I also apply the decomposition rules from Inductor in order to decompose large operators into aten/prims operators. Here is the code I use:

from transformers import BertTokenizer, BertModel
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp
from torch._functorch.aot_autograd import aot_module_simplified


def custom_backend(gm, inputs):

    # Inner compiler: print the (decomposed) forward graph as a table,
    # then fall back to eager execution of the graph module.
    def _compiler(_gm, _):
        _gm.graph.print_tabular()
        return _gm.forward

    # Let AOT Autograd trace the module, apply Inductor's decompositions,
    # and hand the resulting forward graph to `_compiler`.
    return aot_module_simplified(gm,
                                 inputs,
                                 fw_compiler=_compiler,
                                 decompositions=inductor_decomp)


tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = "Hello world!"
encoded_text = tokenizer(text, return_tensors="pt")

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
model_compiled = dynamo.optimize(backend=custom_backend)(model)

model_compiled(**encoded_text)

However, the FX graph produced by the above code is a little less than satisfactory: the input list of the `output` node is very long, and most of the values in it are unnecessary for my case. Here is the `output` node of the printed torch.fx.graph.Graph:

So my question is: what causes such a long input list for the `output` node? Is there any way to manipulate TorchDynamo's FX graph capture so that the `output` node only contains the values I need? Any insight would be greatly appreciated! :slight_smile:

If you only care about inference, you should run your compiled model under no_grad:

with torch.no_grad():
    model_compiled(**encoded_text)

Most of those outputs in the forward graph correspond to “activations”, which torch.compile will then save for backward.

Since you are not running under no_grad, torch.compile (pessimistically) assumes that you might call .backward() later, so when it creates the forward graph it has to make sure that any potential activations are returned from the forward, in case the user uses autograd.
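
If it helps, here is a minimal sketch (not from the original post) of how you could see this from inside the custom backend, by counting how many values the forward graph's `output` node returns with and without no_grad. It reuses the names from the snippet above (model, encoded_text, aot_module_simplified, inductor_decomp); the counting backend itself is just one way to inspect the graph, not part of the official API.

import torch
import torch._dynamo as dynamo

def counting_backend(gm, inputs):
    def _compiler(_gm, _):
        # The `output` node is the last node of an FX graph; its first
        # argument holds the values the graph returns.
        output_node = list(_gm.graph.nodes)[-1]
        print("forward graph returns", len(output_node.args[0]), "values")
        return _gm.forward

    return aot_module_simplified(gm, inputs,
                                 fw_compiler=_compiler,
                                 decompositions=inductor_decomp)

compiled = dynamo.optimize(backend=counting_backend)(model)
compiled(**encoded_text)            # many extra outputs (saved activations)

dynamo.reset()                      # drop cached graphs before recompiling
compiled = dynamo.optimize(backend=counting_backend)(model)
with torch.no_grad():
    compiled(**encoded_text)        # only the actual model outputs remain

Under no_grad there is no autograd graph to save activations for, so the extra outputs disappear and only the model's real outputs are returned.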

Thank you for your reply! torch.no_grad works like a charm!