I’m trying to use aot_module
to compile forward and backward graph of a model. I want to be able to export the graph to onnx but I’m having trouble figuring out how to do that. Here is an example script
import torch
from functorch.compile import aot_module
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from functorch.compile import make_boxed_func
import torch._C._onnx as _C_onnx
class Linear(torch.nn.Module):
def __init__(self):
super(Linear, self).__init__()
self.linear = torch.nn.Linear(128, 10)
self.activation = torch.nn.ReLU()
def forward(self, *inputs):
input = self.linear(inputs[0])
input = self.activation(input)
return input
def compiler_fn(fx_module: torch.fx.GraphModule, example_inputs):
print(fx_module)
torch.onnx.export(
fx_module,
tuple(example_inputs),
"model.onnx",
training=_C_onnx.TrainingMode.PRESERVE,
do_constant_folding=False,
export_params=False,
)
return make_boxed_func(fx_module)
model = Linear()
model.train()
loss_fn = torch.nn.MSELoss()
with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()) as fake_mode:
input = torch.randn((64, 128), requires_grad=True)
labels = torch.randn((64, 10), requires_grad=True)
c_model = aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
output = c_model(input)
loss = loss_fn(output, labels)
loss.backward()
I get
TypeError: f_(): incompatible function arguments. The following argument types are supported:
1. (self: torch._C.Node, arg0: str, arg1: float) -> torch._C.Node
Invoked with: %5 : Tensor = onnx::Gemm(%primals_3, %t, %primals_2), scope: torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl::
, 'alpha', i1
(Occurred when translating addmm).
and have no idea what I’m doing wrong. Any guidance would be appreciated. Thanks!