Torch.compile - what is the best scope of compilation?

I am looking for clarification on the best point to wrap a model in torch.compile, relative to other wrappers / calls. Some previous answers stated it was subject to change in newer releases, so I am looking for information about the current state of things.
Here’s a summary of what I’ve seen so far:

  1. DDP:
    According to the Distributed Data Parallel — PyTorch main documentation page, the DDP wrapper should definitely be applied first, as a requirement for DDPOptimizer to work properly.
    In contrast, according to the Torch.compile() before or after .cuda() answer by @marksaroufim, compile() should be called first, on the inner module.
    This is reiterated in How should I use torch.compile properly? - #2 by marksaroufim
    Which of these is correct and how big is the impact?
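
For concreteness, the two orderings I'm comparing look roughly like this (a minimal sketch with a toy model; assumes a process group launched via torchrun):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Minimal sketch; run under torchrun so LOCAL_RANK is set by the launcher
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
model = nn.Linear(32, 10).cuda(local_rank)

# Ordering (a): wrap in DDP first, then compile (what the DDP docs describe)
model = torch.compile(DDP(model, device_ids=[local_rank]))

# Ordering (b): compile the inner module first, then wrap in DDP
# model = DDP(torch.compile(model), device_ids=[local_rank])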

  2. .cuda()
    According to the answers by Mark linked above, .cuda() should be called before compile(cuda_model). Is the impact noticeable? The second thread states it shouldn't be a huge deal, but I'd like to make sure.
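
In code, the order suggested there would be the following (toy model, just to illustrate):

import torch
import torch.nn as nn

model = nn.Linear(32, 10)

# Move the model to the GPU first, then compile the CUDA-resident module
model = model.cuda()
model = torch.compile(model)

# as opposed to compiling first and only then calling .cuda() on the compiled module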

  3. AMP
    I can see two main ways:

@torch.compile
def uses_autocast(model, input):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
    return output

and

model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)

Per Interaction of torch.no_grad and torch.autocast context managers with torch.compile · Issue #100241 · pytorch/pytorch · GitHub, the second one seems to be recommended, since the graph breaks on context-manager entry/exit in the first one.

Is this still valid for new versions? Are any changes expected here?

  4. General extent of compilation
    The two scenarios here are:
model = torch.compile(model)

for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

vs

@torch.compile
def train_step(...):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

for input, target in data:
    train_step(...)

As the issue linked above mentions, in theory the bigger the extent of compilation the better, but real-world results may vary.
In my experiments, compiling the whole train-step function sometimes yields improved performance.
However, the docs and tutorials mainly refer to compiling the module, or just a standalone function (not a training step).
Is compiling the whole step recommended / supported? Is there official guidance on the best extent of compilation? (In my experience, including dataloading gets messy).
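
For reference, a fleshed-out version of the second pattern I've been experimenting with looks roughly like this (toy model, loss, and optimizer, purely for illustration):

import torch
import torch.nn as nn

# Toy setup -- any small model/optimizer pair illustrates the same point
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

@torch.compile
def train_step(input, target):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss

data = [(torch.randn(8, 32), torch.randint(0, 10, (8,))) for _ in range(3)]
for input, target in data:
    loss = train_step(input, target)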

Thanks in advance!

These are interesting and good questions! I’m unaware of a “best practice” guide, but I’m sure @marksaroufim would know more about the proper workflow for using torch.compile.

Your sense is correct: basically, the more things Dynamo has to worry about, the more likely things are to be slower.

Answers to your specific questions:

  1. Distributed: Ideally torch.compile() should optimize communication passes, but because it doesn’t quite do so yet, and because AOT Autograd unrolls both the forward and backward pass, you can’t overlap communication effectively. This was fixed for DDP but not yet for FSDP, for example, so my suggestion in the past is outdated for DDP but still correct for FSDP.
  2. .cuda() calls: any call that changes devices is something torch.compile needs to reason about, so it’s best to do these before compilation, but it’s not strictly necessary. Today this also messes up CUDA graphs; you can read more here: multi gpu - Is changing the device in a CUDA Graph node unavailable? - Stack Overflow
  3. AMP: You’re right, ideally compile within the context managers. Support for this will generally be better in the nightlies, and things move too quickly for me to have an authoritative view on what to recommend; basically, try the context manager, and if things are too slow or break, then compile at a more granular level.
  4. Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop. It’s also critical to compile functions instead of modules if your module doesn’t call the forward() or __call__() functions; for example, HuggingFace models tend to use generate(), in which case you want torch.compile(model.generate) (see the sketch below). I haven’t personally looked at compiling data loading yet; it’s interesting, so if you find specific issues please let me know.
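
For the HuggingFace case, a minimal sketch of compiling generate() directly might look like this (assumes the transformers library and a small model such as gpt2, purely for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# generate() does not go through forward()/__call__, so compile it directly
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("torch.compile scope", return_tensors="pt")
output_ids = compiled_generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))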

Hi @marksaroufim, thank you for your reply! This basically explains the things I was looking for. I have a few questions regarding question #4, compiling models vs compiling the training step:

Compiling functions vs modules: When 2.0 was released support for optimizers wasn’t quite there but now it is so yes you should optimize the training loop.

  1. I think 2.0 is the latest stable version as of now? Is support for optimizers further improved in the nightlies?
  2. If I compiled the model rather than the entire training step, would loss.backward() also be JIT compiled? I guess the backward pass is essentially the reverse of the forward pass, so it should benefit from the graph optimizations done for the forward pass as well?

Thanks in advance for answering my questions!

  1. Stable is 2.0.1 now; optimizer support is better there, but to be honest it’s still best in the nightlies
  2. Yes, that’s the intent

Thanks a lot @marksaroufim! Really appreciate the answer.