torch.compile: what is the best scope of compilation?

I am looking for clarification on the best point to wrap a model in torch.compile, relative to other wrappers / calls. Some previous answers stated it was subject to change in newer releases, so I am looking for information about the current state of things.
Here’s a summary of what I’ve seen so far:

  1. DDP:
    According to Distributed Data Parallel — PyTorch main documentation, the DDP wrapper should definitely be applied first, as a requirement for DDPOptimizer to work properly.
    In contrast, according to @marksaroufim’s answer in “Torch.compile() before or after .cuda()”, compile() should be called first, on the inner module.
    This is reiterated in How should I use torch.compile properly? - #2 by marksaroufim
    Which of these is correct and how big is the impact?

  2. .cuda()
    According to Mark’s answers linked above, .cuda() should be called before compile(cuda_model). Is the impact noticeable? The second thread says it shouldn’t be a huge deal, but I’d like to make sure.

  3. AMP
    I can see two main ways:

```python
@torch.compile
def uses_autocast(model, input):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
    return output
```

and

```python
model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
```

Per Interaction of torch.no_grad and torch.autocast context managers with torch.compile · Issue #100241 · pytorch/pytorch · GitHub, the second one seems to be recommended, as the graph breaks on context manager entry/exit.

Is this still valid for new versions? Are any changes expected here?

  4. General extent of compilation
    The two scenarios here are:
```python
model = torch.compile(model)

for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```

vs

```python
@torch.compile
def train_step(...):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

for input, target in data:
    train_step(...)
```

As the issue linked above mentions, in theory the larger the scope of compilation, the better, but real results may vary.
In my experiments, compiling the whole train-step function sometimes improves performance.
However, the docs and tutorials mainly refer to compiling the module, or just a standalone function (not a training step).
Is compiling the whole step recommended / supported? Is there official guidance on the best extent of compilation? (In my experience, including dataloading gets messy).
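For item 3, here is a minimal CPU-only sketch of the second pattern: compile the module, then call it inside the autocast region. It swaps CUDA/float16 for CPU autocast with bfloat16 so it runs without a GPU, and it uses backend="aot_eager" only to avoid needing a C++/Triton toolchain. Both choices are my assumptions, not anything from the linked issue; drop the backend argument to get the default Inductor backend.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 2)

# backend="aot_eager" is only used so this sketch runs without a compiler
# toolchain; omit it to use the default Inductor backend.
compiled_model = torch.compile(model, backend="aot_eager")

inp = torch.randn(4, 8)

# Pattern 2: the autocast context is entered outside the compiled code,
# so the compiled region itself contains no context-manager entry/exit.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = compiled_model(inp)

print(output.dtype, tuple(output.shape))
```

Autocast state is still respected inside the compiled region, so `nn.Linear` produces a bfloat16 output here, just as it would in eager mode.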

Thanks in advance!


These are interesting and good questions! I’m unaware of a “best practice” guide, but I’m sure @marksaroufim would know more about the proper workflow to use torch.compile.

Your sense is correct: basically, the more things Dynamo has to worry about, the more likely things will be slower.

Answers to your specific questions:

  1. Distributed: Ideally, torch.compile() would optimize communication passes, but it doesn’t do so quite yet, and because AOT Autograd unrolls both the forward and backward pass, you can’t overlap communication effectively. This was fixed for DDP but not yet for FSDP, for example, so my earlier suggestion is outdated for DDP but still correct for FSDP.
  2. .cuda() calls: any call that changes devices is something torch.compile needs to reason about, so it’s best to make these calls before compilation, but it’s not necessary. Today this also causes problems for CUDA graphs; you can read more here: multi gpu - Is changing the device in a CUDA Graph node unavailable? - Stack Overflow
  3. AMP: You’re right, ideally compile within context managers. Support for this will generally be better in nightlies, and things move too quickly for me to have an authoritative view on what to recommend. Basically, try the context manager, and if things are too slow or break, compile at a more granular level.
  4. Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop. It’s also critical to compile functions instead of modules if your code doesn’t go through the module’s forward() or __call__(). For example, HuggingFace models tend to use generate, in which case you want torch.compile(model.generate). I haven’t personally looked at compiling data loading yet; it’s interesting, so if you find specific issues, please let me know.
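Items 2 and 4 can be combined in one minimal sketch: move the model to its device first, then compile the method that is actually invoked. `TinyGenerator` is a hypothetical stand-in for a model whose entry point is `generate()` (it is not the HuggingFace API). `backend="aot_eager"` is used only so the sketch runs without a compiler toolchain.

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical model whose entry point is generate(), not forward()."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

    def generate(self, x, steps=3):
        # Calls self.proj directly and never goes through self.forward().
        for _ in range(steps):
            x = torch.tanh(self.proj(x))
        return x


device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
model = TinyGenerator().to(device)  # device move happens before compiling
model.eval()

# Compile the bound method that is actually called, not the module.
compiled_generate = torch.compile(model.generate, backend="aot_eager")

x = torch.randn(2, 4, device=device)
with torch.no_grad():
    eager_out = model.generate(x)
    compiled_out = compiled_generate(x)

print(torch.allclose(eager_out, compiled_out, atol=1e-5))
```

Calling `torch.compile(model).generate(x)` instead would not compile this loop. The compiled wrapper only intercepts `__call__`/`forward()`, and `generate()` is forwarded to the original, uncompiled module.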

Hi @marksaroufim, thank you for your reply! This basically explains things I am looking for. I have a few questions regarding question #4, compiling models vs compiling the training step:

Compiling functions vs modules: When 2.0 was released support for optimizers wasn’t quite there but now it is so yes you should optimize the training loop.

  1. I think 2.0 is the latest stable version as of now? Is support for optimizers better in nightly?
  2. If I compile the model rather than the entire training step, would loss.backward() also be JIT compiled? I guess the backward pass is just the reverse of the forward pass, so it should benefit from the graph optimization done in the forward pass as well?

Thanks in advance for answering my questions!

  1. Stable is 2.0.1 now. Optimizer support is better there, but tbh it’s still best in nightlies.
  2. Yes that’s the intent
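To check that compiling just the model also compiles the backward, you can use a custom backend that counts the graphs AOT Autograd hands to its forward and backward compilers. This is a minimal sketch. It assumes the `aot_autograd` helper in `torch._dynamo.backends.common` and `make_boxed_func` from `functorch.compile`. Both appear in the PyTorch custom-backend docs, but they are semi-private paths that may move. `counting_compiler` is hypothetical and performs no optimization.

```python
import torch
import torch.nn as nn
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

compiled_graphs = {"forward": 0, "backward": 0}


def counting_compiler(kind):
    # Builds an AOT Autograd "compiler" that records that it was called and
    # then runs the captured FX graph unchanged.
    def compiler(gm, example_inputs):
        compiled_graphs[kind] += 1
        return make_boxed_func(gm.forward)

    return compiler


backend = aot_autograd(
    fw_compiler=counting_compiler("forward"),
    bw_compiler=counting_compiler("backward"),
)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
compiled_model = torch.compile(model, backend=backend)

loss = compiled_model(torch.randn(4, 8)).sum()
loss.backward()  # runs the backward graph that AOT Autograd traced

print(compiled_graphs)
```

Both counters end up non-zero, and the gradients land on the original parameters. With the default Inductor backend, the backward graph would be optimized in the same way.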

Thanks a lot @marksaroufim! Really appreciate the answer.

@marksaroufim what’s the state of things now? I feel like this should go into the FAQ, which talks about DDP but never makes a suggestion for what to actually do.

For DDP and compile:

(1) Running them together should work by default (if not, let us know!). The “default” integration basically takes the graph that we capture, splits it into buckets/regions, and compiles each bucket separately; the idea is that we can issue all-reduces in the backward that overlap with each (compiled) bucket.

The main benefit of this integration is that it’s pretty simple (if you are seeing speedups compared to eager, then great). One potential drawback, though, is that compile won’t be able to fuse things across buckets. In practice, this may or may not actually matter for getting good performance.
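Here is a minimal single-process sketch of the default path. The DDP wrapper is applied first, and torch.compile is applied to the DDP module (the same order as `torch.compile(DDP(m))` in the snippet further down). To stay self-contained, it assumes a CPU-only, `world_size=1` gloo process group initialized from a temporary file. It also uses `backend="aot_eager"` instead of Inductor. A real job would launch one process per GPU, typically with NCCL.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "world" on CPU, only so the sketch runs on its own.
init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
dist.init_process_group(
    "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

# Wrap in DDP first, then compile the DDP module. With the default
# integration, Dynamo splits the captured graph at DDP bucket boundaries
# and compiles each piece separately.
ddp_model = DDP(model)
compiled = torch.compile(ddp_model, backend="aot_eager")

loss = compiled(torch.randn(4, 8)).sum()
loss.backward()  # DDP's hooks issue the gradient all-reduces

grad_ok = model[0].weight.grad is not None
dist.destroy_process_group()
print(grad_ok)
```

With a model this small, everything fits in one bucket, so you effectively get a single compiled graph. Larger models produce one compiled region per bucket, which is the fusion trade-off described above.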

(2) There is an experimental config you can toggle to try to get better performance, if you’re willing to try out some experimental features that are still being hardened: torch._dynamo.config.optimize_ddp = "python_reducer" (link).

That config basically ensures that the all-reduces in the backward show up inside of the compiled graph, so compile can get one giant backward graph to optimize - containing all of the backward ops, and the all-reduce communications. The main reason it’s not on by default is that it requires “compiled autograd” (Ed has a great podcast on it here). Compiled autograd is another feature in compile that effectively allows us to compile more of the backward than we do ordinarily - it can capture stuff like AccumulateGrad nodes in the backward - although it is currently undergoing a re-write and being further hardened. If you’re willing to try it out, you can do so with:

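```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# opt in to the experimental DDP integration described above
torch._dynamo.config.optimize_ddp = "python_reducer"
# compile your model
m = torch.compile(DDP(m))
# run the forward/backward under compiled autograd
with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    fw_out = m(inp)
    loss = loss_fn(fw_out, ..)
    loss.backward()
```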