We are evaluating distributed training on PT 2.0 with compilation. We noticed that compiling a ~1B-parameter model makes the first few steps slow, and it can take ~10 minutes for training to reach a stable, full-throughput state. I am wondering whether a compiled model can be saved in some intermediate format, so that re-launching training with the same model takes less time.
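For context, the setup looks roughly like the sketch below. The one relaunch-persistence knob I'm aware of is Inductor's on-disk kernel cache, which the `TORCHINDUCTOR_CACHE_DIR` environment variable can redirect to a persistent location; note this is my reading of the behavior, the model and paths here are placeholders, and the disk cache only covers compiled kernel artifacts, not Dynamo's graph capture, so it may not eliminate all of the warm-up:

```python
import os

# Assumption: Inductor keeps an on-disk kernel cache (by default under
# /tmp/torchinductor_<user>) and TORCHINDUCTOR_CACHE_DIR redirects it.
# Set it before importing torch so it is picked up early.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/checkpoints/inductor_cache"

import torch
import torch.nn as nn

# Placeholder model standing in for the ~1B-parameter model.
model = nn.Sequential(
    nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
).cuda()
compiled = torch.compile(model)  # PT 2.0 compilation entry point

x = torch.randn(8, 1024, device="cuda")
# The first call pays the compilation cost; subsequent launches that see
# the same persistent cache directory may skip some kernel compilation.
compiled(x)
```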
As a workaround for the currently missing savable-cache feature, would it be possible to pickle that cache manually? As Zhaoqi asks, where is the cache currently kept? Is it held in class fields?
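To make the idea concrete, here is a rough sketch of the kind of manual workaround I have in mind, assuming the on-disk Inductor cache is the piece worth persisting; the directory names are hypothetical, and as far as I can tell Dynamo's guard/bytecode cache lives in memory (attached to the traced code objects), so pickling that part directly is probably not straightforward:

```python
import shutil
from pathlib import Path

# Hypothetical paths; the default Inductor cache location
# (/tmp/torchinductor_<user>) is my assumption from reading the source.
LIVE_CACHE = Path("/tmp/torchinductor_myuser")
SNAPSHOT = Path("/checkpoints/inductor_cache_snapshot")

def save_cache_snapshot() -> None:
    """Copy the on-disk kernel cache after a warmed-up run."""
    if LIVE_CACHE.exists():
        shutil.copytree(LIVE_CACHE, SNAPSHOT, dirs_exist_ok=True)

def restore_cache_snapshot() -> None:
    """Restore the snapshot before relaunching training."""
    if SNAPSHOT.exists():
        shutil.copytree(SNAPSHOT, LIVE_CACHE, dirs_exist_ok=True)

# Note: this only covers Inductor's compiled-kernel artifacts; Dynamo's
# in-memory graph/guard cache would still be rebuilt on relaunch.
```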
Also, this is potentially out of scope for this question, but I would be interested in a "save cache" feature for TorchScript too. Is there anything like that already implemented?
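To clarify what I mean by "save cache" here: `torch.jit.save` already persists the scripted module itself, but as far as I know it does not persist the profiling executor's runtime optimizations, which are rebuilt over the first few calls after loading; that per-run optimization state is what I'd like to save. A minimal illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
scripted = torch.jit.script(model)

# torch.jit.save/load persist the scripted IR and weights...
torch.jit.save(scripted, "model.pt")
restored = torch.jit.load("model.pt")

# ...but (as far as I know) the profiling executor still re-profiles and
# re-optimizes on the first few calls after load; that runtime cache is
# the part I'm asking about saving.
restored(torch.randn(2, 16))
```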