Hey, I’m interested in creating and exporting a pullback model using ONNX. The pullback model wraps an existing model and, for a given input, computes the wrapped model’s output and its gradient w.r.t. the input. (I’m using a virtual environment with Python 3.8.10 and PyTorch 1.10.1.)
The following code achieves what I want within PyTorch:
import torch
import torch.nn as nn
# dummy model we want to wrap
class BaseModel(nn.Module):
    """Dummy model to be wrapped: a single 3x3 convolution.

    ``conv1`` maps 2 input channels to 1 output channel. With
    ``kernel_size=3`` and ``padding=1`` (default stride) the spatial size
    of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv1(x)
# the pullback model we want to export using ONNX
class PullbackModel(nn.Module):
    """Wrap a model and return its output together with the input gradient.

    For an input ``x`` and an output sensitivity ``dy`` the module returns
    ``(y, dx)``, where ``y = forward_model(x)`` and ``dx`` is the gradient
    of ``y`` w.r.t. ``x`` weighted by ``dy`` (the vector-Jacobian product).
    """

    def __init__(self, forward_model):
        """Store ``forward_model``, the module whose pullback is computed."""
        super().__init__()
        self.forward_model = forward_model

    def forward(self, x, dy):
        """Return ``(y, dx)`` for input ``x`` and output sensitivity ``dy``.

        ``dy`` must have the same shape as the wrapped model's output, as
        it is passed to ``backward`` as the ``gradient`` argument.
        """
        # Run on a fresh leaf tensor that tracks gradients. Without this,
        # ``x.grad`` stays ``None`` whenever the caller's ``x`` was created
        # without ``requires_grad=True``, and it would accumulate across
        # calls whenever the caller's ``x`` is a leaf that does track them.
        x = x.detach().requires_grad_(True)
        y = self.forward_model(x)
        # NOTE: this also accumulates into the ``.grad`` of every parameter
        # of ``forward_model`` that requires gradients.
        y.backward(gradient=dy)
        dx = x.grad
        return y, dx
Within PyTorch, I can now compute exactly what I want (ignoring the possibly accumulating gradients for now), i.e.:
# Build the dummy model and wrap it in the pullback model.
model = BaseModel()
pb_model = PullbackModel(model)
# NOTE(review): x is created without requires_grad=True, while
# PullbackModel.forward reads x.grad after backward() -- confirm that x
# tracks gradients at that point, otherwise dx comes back as None.
x = torch.randn(1, 2, 5, 5) # example input
dy = torch.ones(1, 1, 5, 5) # example output sensitivity
# compute the wrapped model's output and the input sensitivity
# (i.e., the gradient w.r.t. the input, weighted by the given output sensitivity)
y, dx = pb_model(x, dy)
This works as expected, but for my use case I need to export the pullback model using ONNX for later inference. However, when I try
# Export pb_model to "pb-model.onnx" using (x, dy) as example inputs,
# naming the graph inputs "x", "dy" and the graph outputs "y", "dx".
torch.onnx.export(pb_model, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
I get the following error:
Traceback (most recent call last):
File "pullback_onnx.py", line 117, in <module>
torch.onnx.export(pb_model, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
return utils.export(model, args, f, export_params, verbose, training,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
_model_to_graph(model, args, verbose, input_names,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "~/my-venv/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "~/my-venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "pullback_onnx.py", line 94, in forward
y.backward(gradient=dy)
File "~/my-venv/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "~/my-venv/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.
As there seemed to be a problem with the trace during `torch.onnx.export`, I tried converting the pullback model to a `ScriptModule` before exporting to ONNX.
For that, I directly traced the wrapped model during the pullback model’s `__init__` using
def __init__(self, forward_model):
    """Store a traced version of ``forward_model`` as the wrapped model."""
    super().__init__()
    # Trace the wrapped model once, using a random example input of shape
    # (1, 2, 5, 5), and keep the traced module instead of the original one.
    self.forward_model = torch.jit.trace(forward_model, torch.rand(1, 2, 5, 5))
and converted `pb_model` to a `ScriptModule` using
# torch.jit.script compiles the module from its source and needs no example
# inputs. Its second positional parameter is the deprecated `optimize` flag,
# so passing (x, dy) there was a misuse; only the module is passed.
scripted_module = torch.jit.script(pb_model)
before trying to export the model again:
# Export the scripted module to "pb-model.onnx", with the same example
# inputs and input/output names as in the first export attempt.
torch.onnx.export(scripted_module, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
Now, I get a different error:
File "pullback_onnx.py", line 118, in <module>
torch.onnx.export(scripted_module, (x, dy), "pb-model.onnx", input_names=["x", "dy"], output_names=["y", "dx"])
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
return utils.export(model, args, f, export_params, verbose, training,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
_model_to_graph(model, args, verbose, input_names,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 497, in _model_to_graph
graph = _optimize_graph(graph, operator_export_type,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 216, in _optimize_graph
graph = torch._C._jit_pass_onnx(graph, operator_export_type)
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 373, in _run_symbolic_function
return utils._run_symbolic_function(*args, **kwargs)
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 1137, in _run_symbolic_function
symbolic_fn = _find_symbolic_in_registry(domain, symbolic_name, opset_version,
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 982, in _find_symbolic_in_registry
return sym_registry.get_registered_op(op_name, domain, opset_version)
File "~/my-venv/lib/python3.8/site-packages/torch/onnx/symbolic_registry.py", line 125, in get_registered_op
raise RuntimeError(msg)
RuntimeError: Exporting the operator prim_grad to ONNX opset version 9 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
I tried exporting using different opset versions which didn’t change the outcome. Looking up the “prim_grad” operator also didn’t yield any helpful results.
Is exporting such a model with ONNX even possible? If yes, what changes would be needed?
Thank you in advance!
PS: As this is my first topic, I’m not sure which category this question should belong to and if I provided all necessary information. Please let me know if there is something missing!