How to save/load a model with torch.compile

I have a model compiled with torch.compile, and I noticed that torch.compile adds the prefix ‘_orig_mod.’ to every key in the model's state_dict(). However, I want to load these weights into a non-compiled model, so I have to strip this prefix manually.
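The manual workaround described above can be sketched as follows. This is a minimal, torch-free sketch: `strip_compile_prefix` and `COMPILE_PREFIX` are hypothetical names, and the string values stand in for tensors. Only keys that start with the prefix are renamed, so the helper is harmless on a state_dict from an uncompiled model.

```python
from collections import OrderedDict

# Prefix observed on state_dict keys of a torch.compile'd module.
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict, prefix=COMPILE_PREFIX):
    """Hypothetical helper: return a new OrderedDict with `prefix` removed.

    Keys without the prefix are copied unchanged, and only the leading
    occurrence of the prefix is stripped.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Usage with placeholder values standing in for tensors:
compiled_sd = OrderedDict([("_orig_mod.weight", "W"), ("_orig_mod.bias", "b")])
plain_sd = strip_compile_prefix(compiled_sd)
print(list(plain_sd))  # ['weight', 'bias']
```

With a real model this would be used as `model.load_state_dict(strip_compile_prefix(compiled.state_dict()))`. PyTorch also ships `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which does the same renaming in place; it is a semi-internal helper originally written for DDP's `module.` prefix, so check that it is available in your version before relying on it.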

My question is: why is this prefix added? And what is the best practice for saving/loading models when using torch.compile?


For now, save or load the uncompiled model; the compiled wrapper shares its weights with it.
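A minimal sketch of that advice, assuming a recent PyTorch 2.x and a hypothetical toy `nn.Linear` model: keep a reference to the original module, save its state_dict instead of the compiled wrapper's, and load into a plain model with no key renaming. It relies on `torch.compile` storing the original module as the `_orig_mod` attribute (which is also where the prefix comes from); that is an implementation detail and could change between releases.

```python
import io

import torch
import torch.nn as nn

# Hypothetical toy model used only for illustration.
model = nn.Linear(4, 2)

# torch.compile returns a wrapper around `model`. Compilation is deferred
# until the first forward call, so no compiler backend runs in this sketch.
compiled = torch.compile(model)

# The wrapper registers the original module as its submodule `_orig_mod`,
# so its state_dict keys look like "_orig_mod.weight", "_orig_mod.bias".
print(list(compiled.state_dict().keys()))

# Both objects hold the same parameter tensors, so training through
# `compiled` updates `model`, and saving `model` captures those updates.
buffer = io.BytesIO()  # stands in for a checkpoint file path
torch.save(model.state_dict(), buffer)

# Later: load into a fresh, uncompiled model; the keys already match.
buffer.seek(0)
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(buffer, weights_only=True))
```

If you need the compiled model again after loading, call `torch.compile(fresh)` once the weights are in place.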

Longer answer here: "Make compiled models serializable", Issue #101107, pytorch/pytorch on GitHub.