State_dict after compile and onnx.export

Ahoy,

I have a question regarding the difference between a “vanilla” model’s state_dict and the state_dict of the compiled model: how do they relate? It seems to me that the compiled model’s state_dict contains the same parameters, just with the keys prefixed with “_orig_mod.”:

import copy
import torch
from torchvision.models import mobilenet_v3_small

net = mobilenet_v3_small()
s0 = copy.deepcopy(net.state_dict())
compiled_net = torch.compile(net)
s1 = copy.deepcopy(compiled_net.state_dict())

# The parameter values match pairwise, so nothing is printed.
for a, b in zip(s0.values(), s1.values()):
    if torch.any(a != b):
        print("different")
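
The keys, on the other hand, differ exactly by that prefix; a quick check (a sketch, relying on the dicts preserving insertion order):

for k0, k1 in zip(s0.keys(), s1.keys()):
    assert k1 == "_orig_mod." + k0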

What is the rationale behind this? Can I safely strip the “_orig_mod.” prefix and use the parameters in the pre-compile model? I expect not, but I would like to know for sure.
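
To make the question concrete, this is the kind of thing I have in mind (an untested sketch; consume_prefix_in_state_dict_if_present is the helper in torch.nn.modules.utils that strips such key prefixes in place, and fresh_net is just an illustrative name):

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

sd = compiled_net.state_dict()
consume_prefix_in_state_dict_if_present(sd, "_orig_mod.")  # strips the prefix in place

fresh_net = mobilenet_v3_small()
fresh_net.load_state_dict(sd)  # keys now match the uncompiled model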

The reason I am interested in this is that I have trained a compiled model and wanted to export it using torch.onnx.export, which raises an error that does not occur when exporting a model that has not been compiled. So it would be nice to use the learned parameters in the “vanilla” model.

The code I am using goes like this:

# DEV, H, W, parameterfile, and outfile are defined elsewhere.
net = mobilenet_v3_small()
net = net.to(DEV)
net = torch.compile(net)
net.load_state_dict(torch.load(parameterfile,
                               map_location=torch.device(DEV)))

net.eval()

inp = torch.randn([1, 3, H, W], device=DEV)
out = net(inp)

torch.onnx.export(net,
                  inp,
                  outfile)

and results in (with most of the backtrace removed):

============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
...
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_ops.py", line 354, in _get_dispatch
    assert key not in self._dispatch_cache, f"{self} {key}"
AssertionError: aten.empty_strided.default DispatchKey.BackendSelect

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
....
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_ops.py", line 109, in resolve_key
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
NotImplementedError: could not find kernel for aten._local_scalar_dense.default at dispatch key DispatchKey.Meta

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
...
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1173, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError: 

from user code:
   File "/home/drfk/venv/lib/python3.8/site-packages/torchvision/models/mobilenetv3.py", line 220, in forward
    return self._forward_impl(x)
  File "/home/drfk/venv/lib/python3.8/site-packages/torchvision/models/mobilenetv3.py", line 210, in _forward_impl
    x = self.features(x)

Any insights into any of this would be much appreciated.
Cheers

We’re still debating what this should look like here: Save/Load OptimizedModule by msaroufim · Pull Request #101651 · pytorch/pytorch · GitHub

But for now, to unblock you, I would suggest avoiding the deepcopies and instead saving the original uncompiled model’s state_dict; the state dicts should be shared between the compiled and the uncompiled model.
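
Something like this (a sketch; the variable names are only illustrative):

net = mobilenet_v3_small().to(DEV)
compiled_net = torch.compile(net)  # wraps net; the parameters are shared, not copied

# ... train with compiled_net ...

# Save the *uncompiled* module's state_dict: same tensors, no "_orig_mod." prefix.
torch.save(net.state_dict(), parameterfile)

# For ONNX export, load into a plain uncompiled model.
export_net = mobilenet_v3_small().to(DEV)
export_net.load_state_dict(torch.load(parameterfile, map_location=DEV))
export_net.eval()
torch.onnx.export(export_net, torch.randn(1, 3, H, W, device=DEV), outfile)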

Thanks.

As you pointed out in the GitHub issue, it is indeed confusing to have the state_dict keys prefixed with _orig_mod., especially as I was unable to find any documentation about this. I have also discussed this with colleagues who were wondering about it too, so it is not just me :wink:

I’ll keep an eye on the GitHub issue. For now I assume it is OK to take the _orig_mod. state_dict, strip the prefix, and load it into an uncompiled model. In any case, the results seem to match between the compiled and the uncompiled model.
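
For the record, a rough version of the check I did (continuing the names from my first snippet; the tolerance is arbitrary):

fresh_net.eval()
compiled_net.eval()  # disable dropout for a deterministic comparison

with torch.no_grad():
    inp = torch.randn(1, 3, 224, 224)
    assert torch.allclose(fresh_net(inp), compiled_net(inp), atol=1e-6)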
