Briefly: a `fw_compiler` receives a graph of ATen IR nodes, but some nodes with `node.op == "call_function"` have a `node.target` that appears to point to the wrong `OpOverload`. Here's the simplest possible example:
#!/usr/bin/env python3
from typing import Any, Callable
import torch
import torch.fx
from torch._functorch.aot_autograd import aot_module_simplified
def my_fw_compiler(
    gm: torch.fx.graph_module.GraphModule,
    sample_input: Any,
) -> Callable[..., Any]:
    """Print every non-placeholder/non-output node of ``gm``, then run it as-is.

    For each such node this prints its ``op`` and ``target``, the target's
    ``_schema`` when it has one, and the type and value of every positional
    and keyword argument.

    Args:
        gm: The graph module to inspect (the forward graph when used as
            ``fw_compiler`` for ``aot_module_simplified``).
        sample_input: Example inputs passed alongside ``gm``; unused here.

    Returns:
        ``gm.forward``, so the graph executes unchanged.

    Note:
        Printed args can be plain Python numbers even where the schema
        declares ``Tensor`` (e.g. ``aten.add.Tensor`` receiving ``float 0.5``).
        An argument-type checker over these nodes presumably has to accept a
        scalar for a ``Tensor`` parameter; confirm against ATen's
        scalar-wrapping rules.
    """
    assert isinstance(gm, torch.fx.graph_module.GraphModule)
    for node in gm.graph.nodes:
        assert isinstance(node, torch.fx.node.Node)
        if node.op in ("placeholder", "output"):
            continue
        print(node.op, node.target)
        # Not every target carries a ``_schema``. ``get_attr`` nodes have string
        # targets, and ``call_function`` may target plain Python callables such
        # as ``operator.getitem``. Accessing ``_schema`` unconditionally would
        # raise AttributeError on those nodes.
        schema = getattr(node.target, "_schema", None)
        if schema is not None:
            print(" schema: ", schema)
        rendered_args = ", ".join(f"{type(a).__name__} {a}" for a in node.args)
        print(" args: ", "(", rendered_args, ")")
        # Keyword arguments (e.g. ``alpha=``) also need type checking, so show them too.
        if node.kwargs:
            rendered_kwargs = ", ".join(
                f"{key}={type(value).__name__} {value}"
                for key, value in node.kwargs.items()
            )
            print(" kwargs: ", "{", rendered_kwargs, "}")
    return gm.forward  # type: ignore
def my_dynamo_backend(
    gm: torch.fx.graph_module.GraphModule,
    sample_input: Any,
) -> Callable[..., Any]:
    """Dynamo backend that hands ``gm`` to AOT autograd.

    AOT autograd's forward graph is routed to ``my_fw_compiler`` for
    inspection.

    Args:
        gm: FX graph captured by ``torch.compile``.
        sample_input: Example inputs for ``gm``, forwarded unchanged.

    Returns:
        The callable built by ``aot_module_simplified``.
    """
    assert isinstance(gm, torch.fx.graph_module.GraphModule)
    compiled = aot_module_simplified(gm, sample_input, fw_compiler=my_fw_compiler)
    return compiled
class SimplestModule(torch.nn.Module):
    """Minimal module whose forward adds the Python scalar ``0.5`` to its input."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Tensor + Python float: the single operation this repro inspects.
        result = input + 0.5
        return result
def main() -> None:
    """Run ``SimplestModule`` eagerly and compiled, then print both results.

    The first call of the compiled module goes through ``my_dynamo_backend``
    and therefore ``my_fw_compiler``. That call prints the traced nodes before
    the two results are printed here.
    """
    x = torch.ones(4, dtype=torch.float32, device="cpu")
    model = SimplestModule()
    model.eval()
    eager_out = model(x)

    compiled = torch.compile(model, backend=my_dynamo_backend, fullgraph=True)
    assert isinstance(compiled, torch.nn.Module)
    compiled_out = compiled(x)

    print()
    for out in (eager_out, compiled_out):
        print(out)
# Run the repro only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
When running it:
$ ./bug_repro.py
call_function aten.add.Tensor
schema: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
args: ( Node arg0_1, float 0.5 )
tensor([1.5000, 1.5000, 1.5000, 1.5000])
tensor([1.5000, 1.5000, 1.5000, 1.5000])
As can be seen, `node.target` is the `Tensor` overload (`aten.add.Tensor`), but the second argument in `node.args` is clearly not a tensor; it is the Python float `0.5`.
Am I missing something obvious here? Maybe it has to do with scary-sounding terms such as boxing, wrappers, or trace modes (I came across those while digging through the code).
Context: I am toying with a “compiler” that generates some (non-Python) code from the graph. As part of that, I implemented a check that each actual argument of a function call matches the declared parameter type, and to my surprise it failed.
Cheers!