[onnx export] Toy model fails to export to ONNX

Why can't this model be exported to ONNX? After tracing, some keys seem to be missing from the state_dict.

Keys

Before trace:
odict_keys(['weight', 'conv2d.weight', 'conv2d.bias'])

After trace:
odict_keys(['weight', 'conv2d.bias'])
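
For reference, the keys above were collected roughly like this (the explicit torch.jit.trace call is my assumption about how the "after trace" keys were obtained; it mirrors the tracing that torch.onnx.export does internally):

import torch

model = toy_model()
x = torch.rand((1, 6, 64, 64))

# Parameters registered on the eager module.
print(model.state_dict().keys())
# odict_keys(['weight', 'conv2d.weight', 'conv2d.bias'])

# Trace the model (this runs forward once) and inspect the traced module.
traced = torch.jit.trace(model, x)
print(traced.state_dict().keys())
# odict_keys(['weight', 'conv2d.bias'])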

Code example

import torch
from torch import nn
import torch.nn.functional as F


class toy_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv2d = nn.Conv2d(3, 8, 3)
        self.weight = nn.Parameter(torch.rand(8, 3, 3, 3))


    def forward(self, x):
        # Replace the conv weight and in_channels on every forward pass
        self.weight = nn.Parameter(torch.rand(8, 6, 3, 3))
        self.conv2d.weight = self.weight
        self.conv2d.in_channels = 6

        x = self.conv2d(x)

        return x


x = torch.rand((1, 6, 64, 64))

model = toy_model()

torch.onnx.export(
    model,
    x,
    'toy.onnx',
    verbose=True,
    opset_version=14,
    input_names=['input'],
    output_names=['output'],
)

Error

Traceback (most recent call last):
  File "/code/toy_test.py", line 97, in <module>
    torch.onnx.export(
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/utils.py", line 752, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/utils.py", line 521, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/utils.py", line 465, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/sota/miniconda3/envs/sota/lib/python3.9/site-packages/torch/onnx/utils.py", line 420, in _trace_and_get_graph_from_model
    raise RuntimeError("state_dict changed after running the tracer; "
RuntimeError: state_dict changed after running the tracer; something weird is happening in your model!

torch.onnx.export will trace the model, i.e. it will execute the forward pass with the provided input data and record all operations.
In your forward method you are replacing internals of the model, which will most likely break the trace.
I don’t know if the scripting approach would work, but it might be worth a try.
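
If you want to give it a try, the call would look roughly like this (an untested sketch; scripting may also reject creating nn.Parameter inside forward, so the model might need changes either way):

# Script the model instead of tracing it, then export the ScriptModule.
scripted = torch.jit.script(model)

torch.onnx.export(
    scripted,
    x,
    'toy_scripted.onnx',
    opset_version=14,
    input_names=['input'],
    output_names=['output'],
)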

@ptrblck Do you know how to solve this, or do you have any suggestions?

You could try to script your model, but I also don’t know whether this would solve the issue, or whether ONNX simply doesn’t support model manipulation in the forward pass.
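
If the runtime manipulation isn't actually needed and the goal is just an 8-output conv over 6 input channels, creating all parameters once in __init__ and keeping forward side-effect free should avoid the error entirely. A minimal untested sketch (ToyModelFixed and the zero-initialized bias are my assumptions):

import torch
from torch import nn
import torch.nn.functional as F


class ToyModelFixed(nn.Module):
    def __init__(self):
        super().__init__()
        # Create the parameters once with their final shapes.
        self.weight = nn.Parameter(torch.rand(8, 6, 3, 3))
        self.bias = nn.Parameter(torch.zeros(8))

    def forward(self, x):
        # Use the functional API so nothing on the module is replaced
        # while the tracer runs.
        return F.conv2d(x, self.weight, self.bias)


x = torch.rand((1, 6, 64, 64))
model = ToyModelFixed()

torch.onnx.export(
    model,
    x,
    'toy_fixed.onnx',
    opset_version=14,
    input_names=['input'],
    output_names=['output'],
)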