How do I get `torch.onnx.export` to work with `aot_module`?

I’m trying to use `aot_module` to compile the forward and backward graphs of a model. I want to be able to export those graphs to ONNX, but I’m having trouble figuring out how to do that. Here is an example script:

```python
import torch
from functorch.compile import aot_module
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from functorch.compile import make_boxed_func
import torch._C._onnx as _C_onnx


class Linear(torch.nn.Module):
    def __init__(self):
        super(Linear, self).__init__()
        self.linear = torch.nn.Linear(128, 10)
        self.activation = torch.nn.ReLU()

    def forward(self, *inputs):
        input = self.linear(inputs[0])
        input = self.activation(input)
        return input


def compiler_fn(fx_module: torch.fx.GraphModule, example_inputs):
    print(fx_module)
    torch.onnx.export(
        fx_module,
        tuple(example_inputs),
        "model.onnx",
        training=_C_onnx.TrainingMode.PRESERVE,
        do_constant_folding=False,
        export_params=False,
    )
    return make_boxed_func(fx_module)


model = Linear()
model.train()
loss_fn = torch.nn.MSELoss()

with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()) as fake_mode:
    input = torch.randn((64, 128), requires_grad=True)
    labels = torch.randn((64, 10), requires_grad=True)

    c_model = aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn)
    output = c_model(input)
    loss = loss_fn(output, labels)
    loss.backward()
```

I get the following error:

```
TypeError: f_(): incompatible function arguments. The following argument types are supported:
    1. (self: torch._C.Node, arg0: str, arg1: float) -> torch._C.Node

Invoked with: %5 : Tensor = onnx::Gemm(%primals_3, %t, %primals_2), scope: torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl::
, 'alpha', i1 
(Occurred when translating addmm).
```

I have no idea what I’m doing wrong. Any guidance would be appreciated. Thanks!
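A minimal diagnostic sketch, not a fix: instead of exporting inside the compiler callbacks, record which aten ops each graph contains, since those are the ops `torch.onnx.export` has to translate. It assumes real tensors (no `FakeTensorMode`), and `record_ops` and `seen_ops` are hypothetical names. With the bias enabled, the forward graph should include `aten.addmm`, the op named in the error above. The exact op lists may vary between PyTorch versions.

```python
import torch
from functorch.compile import aot_module, make_boxed_func

# Hypothetical helper state: maps "forward"/"backward" to the aten ops seen in that graph.
seen_ops = {}


def record_ops(tag):
    """Return a compiler callback that records call_function targets and runs the graph unchanged."""
    def compiler_fn(fx_module: torch.fx.GraphModule, example_inputs):
        seen_ops[tag] = sorted(
            {str(node.target) for node in fx_module.graph.nodes if node.op == "call_function"}
        )
        return make_boxed_func(fx_module)
    return compiler_fn


class Linear(torch.nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(128, 10, bias=bias)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))


model = Linear(bias=True)
c_model = aot_module(model, fw_compiler=record_ops("forward"), bw_compiler=record_ops("backward"))

x = torch.randn(64, 128, requires_grad=True)
labels = torch.randn(64, 10)
loss = torch.nn.MSELoss()(c_model(x), labels)
loss.backward()  # running backward ensures the backward graph has been compiled

print("forward ops: ", seen_ops["forward"])
print("backward ops:", seen_ops["backward"])
```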


So, if I turn off the bias in the `Linear` layer (`bias=False`), I get further: this gets past the first error on the forward graph. Now I get something like:

```
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::threshold_backward' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
```
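A sketch comparing the two configurations, under the same assumptions (real tensors, no export call, and a hypothetical helper `traced_ops`). With `bias=False`, `nn.Linear` on a 2-D input should lower to `aten.mm` rather than `aten.addmm`, which is consistent with the forward export getting past the first error. ReLU's gradient still shows up in the backward graph as `aten.threshold_backward`, the operator rejected above.

```python
import torch
from functorch.compile import aot_module, make_boxed_func


def traced_ops(bias: bool):
    """Hypothetical helper: run one forward/backward pass through aot_module and
    return the sets of aten ops recorded in the forward and backward graphs."""
    ops = {}

    def record(tag):
        def compiler_fn(fx_module, example_inputs):
            ops[tag] = {str(n.target) for n in fx_module.graph.nodes if n.op == "call_function"}
            return make_boxed_func(fx_module)
        return compiler_fn

    model = torch.nn.Sequential(torch.nn.Linear(128, 10, bias=bias), torch.nn.ReLU())
    c_model = aot_module(model, fw_compiler=record("forward"), bw_compiler=record("backward"))
    c_model(torch.randn(64, 128, requires_grad=True)).sum().backward()
    return ops["forward"], ops["backward"]


for bias in (True, False):
    fw, bw = traced_ops(bias)
    print(f"bias={bias}: addmm in forward: {any('addmm' in op for op in fw)}, "
          f"threshold_backward in backward: {any('threshold_backward' in op for op in bw)}")
```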

So is the first error I showed in the original post a bug, then?