Does `torch.func.functionalize` support `fx.GraphModule` from `dynamo.export` w/ aten graph?

Background: I wanted to remove mutation from my dynamo-exported fx.GraphModule. Part of the workflow looks like this:

gm, _ = torch._dynamo.export(model, dummy_input, aten_graph=True)
gm = proxy_tensor.make_fx(func.functionalize(gm))(dummy_input)

It worked for many simple cases, but I ran into the error below with shufflenet from torchvision.

  File "<eval_with_key>.4", line 15, in forward
  File "/home/bowbao/pytorch/torch/_ops.py", line 398, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: false INTERNAL ASSERT FAILED at "/home/bowbao/pytorch/build/aten/src/ATen/RegisterFunctionalization_2.cpp":7718, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

The full repro script:

import torchvision
import torch
import torch._dynamo
from torch import func
from torch.fx.experimental import proxy_tensor

model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)

# Export to an aten-level fx graph, then attempt to functionalize it
gm, _ = torch._dynamo.export(model, dummy_input, aten_graph=True)
gm = proxy_tensor.make_fx(func.functionalize(gm))(dummy_input)

Dug into this a little. Some context:

(1) That model mutates some of its buffers / params as part of its forward pass (for shufflenet, most likely the BatchNorm running stats, which are updated in-place during a training-mode forward).

(2) functorch.functionalize() has a more limited contract: it will remove mutations on graph inputs and intermediates, but it will not remove mutations of captured variables or global state. When you functionalize a model.forward() call, buffers and parameters count as non-local state (they weren't lifted to be inputs of the function being functionalized); see the sketch after this list.

(3) Some time soon, we're going to add an API to AOT Autograd that will probably do what you want: a way to take a function / model and return a functionalized graph, which also handles other things for you, like flattening input pytrees and lifting module state into graph inputs.
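
To make the contract in (2) concrete, here is a minimal sketch (f, g, and buf are illustrative names, not from the thread): functionalize removes the in-place op on an intermediate, while mutating captured state trips the same assert as above.

import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x + 1
    y.add_(1)  # in-place op on an intermediate: removed by functionalize
    return y

gm = make_fx(functionalize(f))(torch.randn(3))
# gm.graph contains only out-of-place aten.add calls, no aten.add_

buf = torch.zeros(3)

def g(x):
    buf.add_(x)  # in-place op on captured, non-local state
    return buf

# make_fx(functionalize(g))(torch.randn(3)) raises the same
# "mutating a non-functional tensor with a functional tensor" error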

Side note: export will likely perform functionalization automatically soon (this doesn’t happen today though).


From what I saw, you can make it work with current PyTorch by passing an aten decomposition table to make_fx, so that make_fx traces the decomposed graph:

import torch
import torch._dynamo
from torch._decomp import core_aten_decompositions
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def interpret(graph_module):
    # Assumed helper (not defined in the original reply): wrap the exported
    # GraphModule so that functionalize() receives a plain callable.
    def wrapper(*args):
        return torch.fx.Interpreter(graph_module).run(*args)
    return wrapper

def functionalized_export(func, *args, **kwargs):
    graph, _ = torch._dynamo.export(func, *args, aten_graph=True, **kwargs)
    graph = make_fx(
        functionalize(interpret(graph)),
        decomposition_table=core_aten_decompositions()
    )(*args, **kwargs)
    return graph
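
For reference, a usage sketch with the shufflenet repro from above (imports as in the repro script):

model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)

gm = functionalized_export(model, dummy_input)
print(gm.graph)  # decomposed aten graph with the mutations removed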