The forward()
function of my torch model has multiple arguments. For export the model, I use torch.onnx.export()
function. Before export the model, I want to freeze some arguments of the model’s forward()
func and only input the left arguments which are the indeed inputs of the model.
But torch.onnx.export()
give a error: TypeError: forward() got multiple values for argument ...
.
Codes like:
from functools import partial
import torch
class MyModel(torch.nn.Module):
"""A simple model for test."""
def __init__(self) -> None:
super().__init__()
self.weight = 1 # the weight of the model
def forward(self, img, a, b, **kwargs):
print(f"Got {img=}, {a=}, {b=}")
print(f"Got kwargs: {kwargs}")
return img * self.weight + a + b
if __name__ == "__main__":
my_model = MyModel()
old_f = my_model.forward
my_model.forward = partial(
my_model.forward,
a = 10,
b = 100,
ok = False
)
my_model.eval()
ret = my_model(1) # 1 * weight(=1) = 1
print(f"{ret=}") # 111
torch.onnx.export(
my_model,
args=(
1,
"""{
"a": 10,
"b": 100,
"ok": False
},"""
),
f="model.onnx"
)
And the outputs and error:
Got img=1, a=10, b=100
Got kwargs: {'ok': False}
ret=111
Traceback (most recent call last):
File "ptest.py", line 34, in <module>
torch.onnx.export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
return utils.export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
_export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
(env) frank@inspur-dev:~/spot$ python ptest.py
Got img=1, a=10, b=100
Got kwargs: {'ok': False}
ret=111
Traceback (most recent call last):
File "ptest.py", line 33, in <module>
torch.onnx.export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
return utils.export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
_export(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 727, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() got multiple values for argument 'a'
Should one use partia
l like this way to freeze arguments of forward
?
Or have other way to work?