How to use `partial` function to freeze some arguments of `forward` of a PyTorch model for onnx export

The forward() function of my torch model has multiple arguments. For export the model, I use torch.onnx.export() function. Before export the model, I want to freeze some arguments of the model’s forward() func and only input the left arguments which are the indeed inputs of the model.

But torch.onnx.export() give a error: TypeError: forward() got multiple values for argument ....

Codes like:

from functools import partial

import torch


class MyModel(torch.nn.Module):
    """A simple model for test."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = 1  # the weight of the model
    
    def forward(self, img, a, b, **kwargs):
        print(f"Got {img=}, {a=}, {b=}")
        print(f"Got kwargs: {kwargs}")
        return img * self.weight + a + b


if __name__ == "__main__":
    my_model = MyModel()

    old_f = my_model.forward
    my_model.forward = partial(
        my_model.forward,
        a = 10,
        b = 100,
        ok = False
    )
    
    my_model.eval()
    ret = my_model(1)  # 1 * weight(=1) = 1
    print(f"{ret=}")  # 111
    torch.onnx.export(
        my_model,
        args=(
            1,
            """{
                "a": 10,
                "b": 100,
                "ok": False
            },"""
        ),
        f="model.onnx"
    )

And the outputs and error:

Got img=1, a=10, b=100
Got kwargs: {'ok': False}
ret=111
Traceback (most recent call last):
  File "ptest.py", line 34, in <module>
    torch.onnx.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
(env) frank@inspur-dev:~/spot$ python ptest.py 
Got img=1, a=10, b=100
Got kwargs: {'ok': False}
ret=111
Traceback (most recent call last):
  File "ptest.py", line 33, in <module>
    torch.onnx.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/general-user/frank/spot/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() got multiple values for argument 'a'

Should one use partial like this way to freeze arguments of forward?
Or have other way to work?

1 Like