How is training optimized?

I just read about PyTorch 2.0, specifically AOTAutograd and training optimization. I'm curious: how exactly is training optimized? Here's my very high-level speculation:

  1. torch.compile traces the forward and backward computation graph (in some IR format)
  2. at some point, the graph goes through optimizations like fusion (or even lowering to some backend?)
  3. now we can use the result of (2) to run the forward & backward pass, and benefit from the optimizations.
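As a toy illustration of steps (1)–(3), here is a sketch in plain Python. To be clear: this is a made-up list-of-ops "IR", not PyTorch's actual FX/ATen graphs or Inductor's fusion logic; it only mirrors the high-level trace → optimize → reuse flow.

```python
# Toy sketch of trace -> optimize -> run. The "IR" here is just a list of
# op names; real torch.compile uses FX graphs and far richer fusion rules.

ELEMENTWISE = {"add", "mul", "relu"}

def trace(ops):
    # Pretend we captured the computation graph as a flat list of op names.
    return list(ops)

def fuse(graph):
    # "Optimization": merge runs of adjacent elementwise ops into one fused op.
    fused, run = [], []
    for op in graph:
        if op in ELEMENTWISE:
            run.append(op)
        else:
            if run:
                fused.append("fused(" + "+".join(run) + ")")
                run = []
            fused.append(op)
    if run:
        fused.append("fused(" + "+".join(run) + ")")
    return fused

graph = trace(["matmul", "add", "relu", "matmul", "mul", "add"])
optimized = fuse(graph)
print(optimized)  # ['matmul', 'fused(add+relu)', 'matmul', 'fused(mul+add)']
```

The fused graph from step (2) is then what actually gets executed in step (3), for both the forward and the traced-out backward pass.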

And some questions:

  1. How does torch.compile(model) trace the graph? I don’t think it has access to the backward pass yet. So does it trace the backward pass lazily?
  2. When is the optimization triggered? Do we cache the optimized graph for the next training iteration?

If there's a better place to post this question, please let me know as well :).


Your speculation is mostly correct for (1) and (2). As for your questions: the optimized graph is cached and only recompiled if some guard assumptions are violated (see Guards Overview — PyTorch master documentation). For question 1, @chillee will probably explain it better than me.
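A minimal sketch of the guard-and-cache idea (the names and the guard key are illustrative, not PyTorch internals; real guards cover much more, e.g. shapes, dtypes, closed-over values, and global state):

```python
# Hypothetical guard-based cache: the compiled artifact is stored under a
# guard key derived from the input, and reused as long as the guards hold.

compile_count = 0
_cache = {}

def expensive_compile(shape, dtype):
    # Stand-in for the real trace + optimize pipeline.
    global compile_count
    compile_count += 1
    return lambda: f"ran optimized graph for {shape}/{dtype}"

def compiled_run(x_shape, x_dtype):
    guard = (x_shape, x_dtype)      # guard key derived from the input
    if guard not in _cache:         # guard "violated" -> (re)compile
        _cache[guard] = expensive_compile(x_shape, x_dtype)
    return _cache[guard]()

compiled_run((8, 16), "float32")   # first iteration: compiles
compiled_run((8, 16), "float32")   # next iteration: cache hit, no recompile
compiled_run((4, 16), "float32")   # new shape: guard check fails, recompiles
print(compile_count)  # 2
```

So across training iterations with stable shapes and dtypes, compilation happens once up front and every later step reuses the cached graph.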


We trace out the backwards pass “ahead of time” (hence AOTAutograd). Although we are not actually able to execute the backwards pass until later (when we have gradOut), we know the shape of gradOut, and thus can trace it out “ahead of time”.
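As a loose analogy for the point above (purely illustrative; the names and structure are invented, and AOTAutograd's real machinery works on actual tensor ops): because we know gradOut's shape in advance, the backward steps can be recorded as a graph before any gradient values exist, and executed later.

```python
# Toy sketch: for y = relu(x @ w), build the backward *graph* ahead of time
# from shapes alone; the graph is only executed later, once grad_out exists.

def trace_backward(x_shape, w_shape):
    # grad_out has y's shape, which we can compute now, "ahead of time".
    y_shape = (x_shape[0], w_shape[1])
    steps = [
        ("mask_by_relu", y_shape),              # grad_y = grad_out * (y > 0)
        ("matmul_grad_w", (x_shape, y_shape)),  # grad_w = x.T @ grad_y
        ("matmul_grad_x", (y_shape, w_shape)),  # grad_x = grad_y @ w.T
    ]
    return steps

backward_graph = trace_backward((2, 3), (3, 4))
print([name for name, _ in backward_graph])
# ['mask_by_relu', 'matmul_grad_w', 'matmul_grad_x']
```

Note that nothing here required a concrete grad_out value, only its shape, which is why the backward pass can be traced (and optimized) before training ever runs a step.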

You can read more about it here: AOTAutograd - Google Slides
