Ahoi,
I have a question regarding the difference between a “vanilla” model’s state_dict and the state_dict of the compiled model: How do they relate? It seems to me that the state_dict of the compiled model consists of the same parameters, just with a prefix “_orig_mod.”
import copy
import torch
from torchvision.models import mobilenet_v3_small
net = mobilenet_v3_small()
s0 = copy.deepcopy(net.state_dict())
compiled_net = torch.compile(net)
s1 = copy.deepcopy(compiled_net.state_dict())
for a, b in zip(s0.values(), s1.values()):
    if torch.any(a != b):
        print("different")
What is the rationale behind this? Can I safely remove the “_orig_mod.” prefix and load the parameters into the pre-compile model? I expect not, but I would like to know for sure.
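To make the question concrete, this is roughly what I mean by removing the prefix. It is only a minimal sketch: strip_prefix is a hypothetical helper of my own (not a PyTorch function), the two keys are illustrative stand-ins for a real state_dict, and it assumes the keys of the two models differ by nothing but the “_orig_mod.” prefix.

```python
from collections import OrderedDict

PREFIX = "_orig_mod."  # prefix observed on the compiled model's keys


def strip_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with `prefix` removed from the start of each key."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Keys without the prefix are passed through unchanged.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in stripped:
            raise KeyError(f"duplicate key after stripping: {new_key}")
        stripped[new_key] = value
    return stripped


# Illustrative stand-in for a compiled model's state_dict;
# the values would normally be tensors.
compiled_sd = OrderedDict([
    ("_orig_mod.features.0.0.weight", "w"),
    ("_orig_mod.classifier.3.bias", "b"),
])
vanilla_sd = strip_prefix(compiled_sd)
print(list(vanilla_sd))

# Intended use (an untested assumption on my side):
#   net = mobilenet_v3_small()
#   net.load_state_dict(strip_prefix(torch.load(parameterfile)))
```

The helper only renames keys and leaves the values untouched, so whether the result is valid depends entirely on the compiled model holding the same parameters as the original one.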
The reason I am interested in this is that I have trained on a compiled model and wanted to export it using torch.onnx.export, which gives me an error that does not arise if I try to export a model that has not been compiled. So it would be nice to use the learned parameters in the “vanilla” model.
The code I am using looks like this:
net = mobilenet_v3_small()
net = net.to(DEV)
net = torch.compile(net)
net.load_state_dict(torch.load(parameterfile,
                               map_location=torch.device(DEV)))
net.eval()
inp = torch.randn([1,3,H,W], device=DEV)
out = net(inp)
torch.onnx.export(net,
                  inp,
                  outfile)
This results in the following (most of the backtraces removed):
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
Traceback (most recent call last):
  ...
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_ops.py", line 354, in _get_dispatch
    assert key not in self._dispatch_cache, f"{self} {key}"
AssertionError: aten.empty_strided.default DispatchKey.BackendSelect

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  ....
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_ops.py", line 109, in resolve_key
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
NotImplementedError: could not find kernel for aten._local_scalar_dense.default at dispatch key DispatchKey.Meta

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  ...
  File "/home/drfk/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1173, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:

from user code:
  File "/home/drfk/venv/lib/python3.8/site-packages/torchvision/models/mobilenetv3.py", line 220, in forward
    return self._forward_impl(x)
  File "/home/drfk/venv/lib/python3.8/site-packages/torchvision/models/mobilenetv3.py", line 210, in _forward_impl
    x = self.features(x)
Any insights into any of this would be much appreciated.
Cheers