It worked for many simple cases, but I ran into the following error with shufflenet from torchvision:
File "<eval_with_key>.4", line 15, in forward
File "/home/bowbao/pytorch/torch/_ops.py", line 398, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: false INTERNAL ASSERT FAILED at "/home/bowbao/pytorch/build/aten/src/ATen/RegisterFunctionalization_2.cpp":7718, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.
(1) That model is mutating some of its buffers/params as part of its forward pass.
(2) functorch.functionalize() has a more limited contract: it will remove mutations on graph inputs and intermediates, but it will not remove mutations of captured variables or global state. When you're functionalizing a model.forward() call, buffers/parameters count as non-local state (they weren't lifted to be inputs of the function being functionalized).
(3) Some time soon, we’re going to add an API to aot autograd that will probably do what you want - a way to take a function / model and return a functionalized graph, that also handles other stuff for you like flattening input pytrees, and lifting module state into graph inputs.
Side note: export will likely perform functionalization automatically soon (this doesn’t happen today though).
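To make (1)-(3) concrete, here is a hypothetical minimal module (not shufflenet itself; the names are made up for illustration) that mutates a buffer in its forward, alongside a manually "lifted" version in the spirit of (3), where the buffer is passed in and returned as an explicit input/output so the function becomes pure:

```python
import torch
from torch.func import functionalize  # functorch.functionalize on older releases

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(()))

    def forward(self, x):
        # Mutates non-local state (a buffer) with a value derived from the
        # input. Under functionalize the input is a functional tensor while
        # the buffer is not, which is exactly the "mutating a non-functional
        # tensor with a functional tensor" case from the error above.
        self.count.add_(x.sum())
        return x + self.count

def forward_lifted(count, x):
    # The buffer is now an explicit input; the mutation becomes a pure
    # computation, and the updated buffer is returned as an extra output.
    new_count = count + x.sum()
    return x + new_count, new_count

m = Counter()
x = torch.ones(3)
# functionalize(m.forward)(x) would hit the assert; the lifted version is fine:
out, new_count = functionalize(forward_lifted)(m.count, x)
```

This manual lifting is essentially what the upcoming aot autograd API described in (3) would automate, along with pytree flattening.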
From what I saw, you can make it work with current PyTorch by passing an ATen decomposition table to make_fx, so that make_fx traces the decomposed graph:
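The original snippet is not shown here; a sketch of that workaround under current APIs might look like the following (the import paths and the core_aten_decompositions helper are assumptions that vary across torch versions):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize  # functorch.functionalize on older releases
from torch._decomp import core_aten_decompositions  # decomposition table; exact API varies

def f(x):
    return torch.nn.functional.gelu(x)

# Decompose composite ops into core ATen ops while tracing, then
# functionalize the result so the captured graph is mutation-free.
gm = make_fx(functionalize(f),
             decomposition_table=core_aten_decompositions())(torch.randn(4))
print(gm.code)
```

The resulting GraphModule contains only the decomposed, functional ATen ops, which is usually what a downstream exporter wants to consume.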