Exporting a pullback model using ONNX

Hey, I’m interested in creating a pullback model and exporting it to ONNX. The pullback model wraps an existing model and, for a given input x and output sensitivity dy, computes the wrapped model’s output y and the input sensitivity dx, i.e., the vector-Jacobian product of dy with the Jacobian of the output w.r.t. the input. (I’m using a virtual environment with Python 3.8.10 and PyTorch 1.10.1.)
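In formulas (our notation: f is the wrapped model and J_f its Jacobian w.r.t. the input), the pullback maps a pair (x, dy) to (y, dx) with

```latex
y = f(x), \qquad
dx = J_f(x)^{\top} dy, \qquad
dx_i = \sum_j dy_j \, \frac{\partial f_j}{\partial x_i}(x)
     = \frac{\partial}{\partial x_i} \langle dy, f(x) \rangle
```

where dy is held constant in the last expression. This vector-Jacobian product is what y.backward(gradient=dy) accumulates into x.grad.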

The following code achieves what I want within PyTorch:

import torch
import torch.nn as nn


# dummy model we want to wrap
class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
    
    def forward(self, x):
        return self.conv1(x)


# the pullback model we want to export using ONNX
class PullbackModel(nn.Module):
    def __init__(self, forward_model):
        super().__init__()
        self.forward_model = forward_model
    
    def forward(self, x, dy):
        y = self.forward_model(x)
        y.backward(gradient=dy)
        dx = x.grad
        return y, dx
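Since BaseModel is a single convolution, and therefore affine in x, its input sensitivity can in principle be written without autograd: dx is a transposed convolution of dy with the same kernel. The pure-Python sketch below illustrates this in 1D with one channel, kernel size 3, stride 1 and zero padding 1, a simplification of the 2D, two-channel conv above. The names conv1d and conv1d_pullback are hypothetical and not part of PyTorch.

```python
# Pure-Python 1D analogue of BaseModel (single channel, kernel size 3,
# zero padding 1, stride 1) and a hand-written pullback for it.
# Hypothetical, simplified sketch; not how PyTorch implements it.

def conv1d(x, w, b, padding=1):
    """Cross-correlation (what nn.Conv1d/nn.Conv2d compute) with zero padding."""
    k = len(w)
    xp = [0.0] * padding + list(x) + [0.0] * padding
    n_out = len(xp) - k + 1
    return [b + sum(w[j] * xp[i + j] for j in range(k)) for i in range(n_out)]


def conv1d_pullback(x, dy, w, b, padding=1):
    """Return (y, dx) with y = conv1d(x, w, b) and dx = J^T dy.

    Since y[i] = b + sum_j w[j] * x[i + j - padding], each output i
    scatters dy[i] * w[j] back to input index i + j - padding,
    which is a transposed convolution of dy with the same kernel.
    """
    y = conv1d(x, w, b, padding)
    dx = [0.0] * len(x)
    for i, g in enumerate(dy):
        for j, wj in enumerate(w):
            m = i + j - padding  # input index read by output i through tap j
            if 0 <= m < len(x):  # taps landing in the padding contribute nothing
                dx[m] += g * wj
    return y, dx
```

For instance, conv1d_pullback([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.5, -1.0, 2.0], 0.1) returns dx == [-0.5, 1.5, 1.5, 1.0]: interior entries collect all three taps, while each boundary entry loses the tap that falls into the padding. In 2D with several input channels, the corresponding operation should be torch.nn.functional.conv_transpose2d(dy, weight, padding=1) using the conv's own weight; whether a model written that way exports cleanly to ONNX is an assumption not verified here.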

Within PyTorch, I can now compute exactly what I want (ignoring, for now, that gradients may accumulate in x.grad across calls):

model = BaseModel()
pb_model = PullbackModel(model)

x = torch.randn(1, 2, 5, 5, requires_grad=True)  # example input; requires_grad so that x.grad is populated
dy = torch.ones(1, 1, 5, 5)  # example output sensitivity

# compute the wrapped model's output and the input sensitivity
# (i.e., the gradient w.r.t. the input and the given output sensitivity)
y, dx = pb_model(x, dy)
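To confirm that a pullback really returns dx = J^T dy, one can compare it against central finite differences of the scalar ⟨dy, f(x)⟩. The following pure-Python sketch does this for a hypothetical toy model; check_pullback, f, f_pullback, and the step and tolerance values are illustrative assumptions, not part of the original code.

```python
import math


def check_pullback(f, pullback, x, dy, h=1e-6, tol=1e-5):
    """Return True if pullback(x, dy) -> (y, dx) matches f(x) and a central
    finite-difference estimate of the gradient of L(x) = <dy, f(x)>."""
    y, dx = pullback(x, dy)
    if any(abs(a - b) > tol for a, b in zip(y, f(x))):
        return False

    def loss(z):
        # Scalar whose gradient w.r.t. z equals the pullback J^T dy.
        return sum(g * v for g, v in zip(dy, f(z)))

    for i in range(len(x)):
        x_plus = list(x)
        x_plus[i] += h
        x_minus = list(x)
        x_minus[i] -= h
        fd = (loss(x_plus) - loss(x_minus)) / (2 * h)
        if abs(fd - dx[i]) > tol:
            return False
    return True


# Hypothetical toy model f(x) = [x0 * x1, sin(x0)] and its hand-written pullback.
def f(x):
    return [x[0] * x[1], math.sin(x[0])]


def f_pullback(x, dy):
    y = f(x)
    # dx = J^T dy with J = [[x1, x0], [cos(x0), 0]]
    dx = [dy[0] * x[1] + dy[1] * math.cos(x[0]), dy[0] * x[0]]
    return y, dx
```

For real PyTorch models, torch.autograd.gradcheck performs a similar finite-difference comparison on double-precision tensors and is usually the more convenient tool.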

The pullback model works as expected, but for my use case I need to export it to ONNX for later inference. However, when I try

torch.onnx.export(pb_model, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])

I get the following error:

Traceback (most recent call last):
  File "pullback_onnx.py", line 117, in <module>
    torch.onnx.export(pb_model, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "pullback_onnx.py", line 94, in forward
    y.backward(gradient=dy)
  File "~/my-venv/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "~/my-venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.

Since the problem seemed to lie in the tracing step of torch.onnx.export, I tried converting the pullback model to a ScriptModule before exporting it to ONNX.
To do so, I traced the wrapped model directly in the pullback model’s __init__:

    def __init__(self, forward_model):
        super().__init__()
        self.forward_model = torch.jit.trace(forward_model, torch.rand(1, 2, 5, 5))

and converted pb_model to a ScriptModule using

scripted_module = torch.jit.script(pb_model)  # the 2nd positional arg of script() is the deprecated `optimize` flag, not example inputs

before trying to export the model again:

torch.onnx.export(scripted_module, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])

Now, I get a different error:

  File "pullback_onnx.py", line 118, in <module>
    torch.onnx.export(scripted_module, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 497, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 216, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 373, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 1137, in _run_symbolic_function
    symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version,
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 982, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "~/my-venv/lib/python3.8/site-packages/torch/onnx/symbolic_registry.py", line 125, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator prim_grad to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

I tried exporting with different opset versions, which didn’t change the outcome. Searching for the “prim_grad” operator didn’t turn up anything helpful either.

Is it possible to export such a model to ONNX at all? If so, what changes would be needed?

Thank you in advance!

PS: As this is my first topic, I’m not sure which category this question belongs in or whether I’ve provided all the necessary information. Please let me know if anything is missing! :slight_smile:
