Is there any way to save multiple graphs from same model without redundantly saving their state dicts?

Hi, I want to save the prefill and decode phase graphs of a generative model (e.g., llama3.1) using torch.export.save. It looks like the prefill and decode phases cannot be represented by a single graph, so I called torch.export.export and torch.export.save twice each, and this saves the model’s state dict twice. I want to eliminate this redundancy. Is there any way to save the model’s state dict once and share it between the two exported graphs?

I’m doing the same thing. Do you have a solution?

I think PyTorch provides no direct option for sharing one state_dict across multiple exported graphs, so exporting several phases of the same model leaves unnecessary duplicate copies of the weights.