How to save/load a model with torch.compile

I have a model compiled with torch.compile, and I found that torch.compile adds the prefix ‘_orig_mod.’ to every key in the model's state_dict(). However, I want to load these weights into a non-compiled model, so I have to remove this prefix manually.

My question is: why is this prefix added? And what is the best practice for saving and loading models when using torch.compile?
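A minimal sketch of the behaviour described above, assuming PyTorch 2.x. `nn.Linear(4, 2)` is only a stand-in model and `strip_compile_prefix` is a hypothetical helper name. Creating the compiled wrapper is lazy, so nothing is actually compiled here; the wrapper simply registers the original module under the name `_orig_mod`, which is where the key prefix comes from.

```python
import torch
import torch.nn as nn

PREFIX = "_orig_mod."  # prefix that the torch.compile wrapper puts in front of every key


def strip_compile_prefix(state_dict):
    # hypothetical helper: copy the dict, dropping PREFIX from keys that start with it
    return {
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state_dict.items()
    }


model = nn.Linear(4, 2)          # stand-in for a real model
compiled = torch.compile(model)  # lazy: compilation waits for the first forward call

print(list(compiled.state_dict().keys()))  # every key starts with '_orig_mod.'

plain = nn.Linear(4, 2)          # a fresh, uncompiled model
plain.load_state_dict(strip_compile_prefix(compiled.state_dict()))
```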


For now, save or load the uncompiled model; the compiled and uncompiled models share the same weights.
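A sketch of that advice, assuming you keep a reference to the module you passed to torch.compile. The compiled wrapper holds that same module (in its internal `_orig_mod` attribute), so training through the wrapper updates `model`, and `model.state_dict()` has no prefix. An in-memory `io.BytesIO` buffer stands in for a checkpoint file such as a hypothetical `model.pt`.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)          # stand-in for a real model
compiled = torch.compile(model)  # wrapper shares parameters with `model`

# ... train using `compiled` ...

buffer = io.BytesIO()            # in-memory stand-in for a checkpoint file
torch.save(model.state_dict(), buffer)  # save from the uncompiled module: keys have no prefix
buffer.seek(0)

restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))  # plain keys load into a plain model
restored_compiled = torch.compile(restored)    # compile again after loading, if needed
```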

Longer answer here: "Make compiled models serializable", Issue #101107 in pytorch/pytorch on GitHub.


A general way to solve this is to override the state_dict() and load_state_dict() methods in the network class you write. Because every custom network inherits from nn.Module, the same approach works for any model, for example:

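from collections import OrderedDict
import torch.nn as nn


class CustomNet(nn.Module):
    def __init__(self, kernel: nn.Module):
        super().__init__()
        self.kernel = kernel  # a plain module or the result of torch.compile(module)

    def _base(self) -> nn.Module:
        # torch.compile wraps the module and keeps the original in `_orig_mod`
        return getattr(self.kernel, "_orig_mod", self.kernel)

    # override state_dict(): return the kernel's weights with no prefix at all
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if destination is None:
            destination = OrderedDict()
        prefix = ""  # remove prefix (drops both "kernel." and "_orig_mod.")
        return self._base().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    # override load_state_dict(): load prefix-free keys straight into the kernel
    def load_state_dict(self, state_dict, strict=True):
        return self._base().load_state_dict(state_dict, strict=strict)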

The code above is only a sketch of the approach; adjust its structure to fit your own model.
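If you would rather not override any methods, PyTorch also ships a small utility, `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which strips a given prefix from a state dict in place. It is usually mentioned for the `module.` prefix added by DistributedDataParallel, but the prefix is just a string. A sketch of fixing a checkpoint saved from a compiled model at load time, with `nn.Linear(4, 2)` standing in for the real model and an in-memory buffer standing in for the checkpoint file:

```python
import io

import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Simulate a checkpoint saved from a compiled model (keys start with '_orig_mod.')
compiled = torch.compile(nn.Linear(4, 2))
buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
torch.save(compiled.state_dict(), buffer)
buffer.seek(0)

checkpoint = torch.load(buffer)
consume_prefix_in_state_dict_if_present(checkpoint, "_orig_mod.")  # rewrites keys in place

plain = nn.Linear(4, 2)  # uncompiled model with the same architecture
plain.load_state_dict(checkpoint)
```

This keeps the model classes untouched and puts the fix-up in the loading code only.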