I have a model compiled with torch.compile, and I found that torch.compile adds a prefix '_orig_mod.' to the keys of the model's state_dict(). However, I want to load these weights into a non-compiled model, so I have to remove this prefix manually.
My question is: why is this prefix added? And what is the best practice for saving/loading models when working with torch.compile?
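For concreteness, here is a minimal sketch of the manual workaround I mean, assuming PyTorch 2.x and a stand-in nn.Linear for the real model; the helper consume_prefix_in_state_dict_if_present from torch.nn.modules.utils strips a key prefix in place:

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

model = nn.Linear(4, 4)             # stand-in for the real model
compiled = torch.compile(model)
state_dict = compiled.state_dict()  # keys look like '_orig_mod.weight'

# Strip the prefix in place so the keys match a non-compiled model again.
consume_prefix_in_state_dict_if_present(state_dict, "_orig_mod.")

fresh = nn.Linear(4, 4)
fresh.load_state_dict(state_dict)   # loads cleanly, no missing/unexpected keys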
In fact, there is a general pattern that solves this problem: override the state_dict() and load_state_dict() methods when you write the network. Because custom networks all inherit from nn.Module, you can intercept these calls in the same way for any model, for example:
from collections import OrderedDict

import torch.nn as nn


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel: nn.Module = ...  # your actual submodule goes here

    # Override state_dict() so the saved keys carry no 'kernel.' prefix.
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        # Write the submodule's parameters directly into destination,
        # dropping the prefix the default implementation would add.
        self.kernel.state_dict(destination=destination, prefix="", keep_vars=keep_vars)
        return destination

    # Override load_state_dict() so the un-prefixed keys load back.
    def load_state_dict(self, state_dict, strict=True):
        return self.kernel.load_state_dict(state_dict, strict=strict)
The above code is only a sketch of the pattern: the '...' placeholder must be replaced with your actual submodule, and you should adapt the structure to fit your own code.
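For example, a short usage sketch, assuming the '...' placeholder above was replaced with nn.Linear(4, 4) for demonstration:

net = CustomNet()
print(list(net.state_dict().keys()))     # ['weight', 'bias'], no 'kernel.' prefix

other = CustomNet()
other.load_state_dict(net.state_dict())  # round-trips through the overrides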