I just read about PyTorch 2.0, specifically AOTAutograd and training optimization. I'm curious: how exactly is training optimized? Here's my very high-level speculation:
- `torch.compile` traces the forward and backward computation graph (in some IR format)
- at some point, the graph goes through optimizations like fusion (or even lowering to some backend?)
- now we can use the result of (2) to run the forward & backward pass, and benefit from the optimizations.
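
To make my speculation concrete, here is roughly the setup I'm picturing (a minimal sketch; the `Linear` model and shapes are just placeholders, and the comments are my guesses, not how it necessarily works):

```python
import torch

# My (unverified) mental model of what happens under the hood:
# 1. the forward graph is captured into some IR
# 2. a backward graph is derived from it ahead of time (AOTAutograd?)
# 3. both graphs get optimized (fusion) and maybe lowered to a backend

model = torch.nn.Linear(128, 10)        # placeholder model, just for illustration
compiled_model = torch.compile(model)   # does tracing happen here, or lazily?

x = torch.randn(32, 128)
out = compiled_model(x)                 # first call: graph capture + optimization?
loss = out.sum()
loss.backward()                         # presumably runs the optimized backward
```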
And some questions:
- How does `torch.compile(model)` trace the graph? I don't think it has access to the backward pass yet. So does it trace the backward pass lazily?
- When is the optimization triggered? Do we cache the optimized graph for the next training iteration?
If there’s a better place to post this question, please let me know as well :).