I am looking for clarification on the best point to wrap a model in torch.compile, relative to other wrappers / calls. Some previous answers stated it was subject to change in newer releases, so I am looking for information about the current state of things.
Here’s a summary of what I’ve seen so far:
.cuda()
According to the answers by Mark linked above, .cuda() should be called before torch.compile, i.e. compile(cuda_model). Is the impact noticeable? The second thread states it shouldn’t be a huge deal, but I’d like to make sure.
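For concreteness, here is the ordering I mean (a minimal sketch; the model is just a placeholder):

import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)

# Move to the GPU first, then compile, so compilation happens against the
# CUDA-resident parameters.
cuda_model = model.cuda()
compiled_model = torch.compile(cuda_model)

# The ordering the linked answers advise against:
# compiled_model = torch.compile(model).cuda()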
AMP
I can see two main ways:
@torch.compile
def uses_autocast(model, input):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
    return output
and
model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
Is this still valid for new versions? Are any changes expected here?
General extent of compilation
The two scenarios here are
model = torch.compile(model)

for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
vs
@torch.compile
def train_step(...):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

for input, target in data:
    train_step(...)
As the issue linked above mentions, in theory the larger the extent of compilation, the better, but real results may vary.
In my experiments, compiling the whole train step function sometimes yields improved performance.
However, the docs and tutorials mainly show compiling the module, or a standalone function (not a training step).
Is compiling the whole step recommended / supported? Is there official guidance on the best extent of compilation? (In my experience, including data loading gets messy.)
These are interesting and good questions! I’m unaware of a “best practice” guide, but I’m sure @marksaroufim would know more about the proper workflow to use torch.compile.
Your sense is correct: basically, the more things Dynamo has to worry about, the more likely it is that things will be slower.
Answers to your specific questions:
Distributed: Ideally torch.compile() should optimize communication passes, but because it doesn’t quite do so yet, and because AOT autograd unrolls both the forward and backward pass, you can’t overlap communication effectively. This has been fixed for DDP but not yet for FSDP, for example, so my past suggestion is now outdated for DDP but still correct for FSDP.
AMP: You’re right, ideally compile within context managers. Support for this will generally be better in nightlies, and things move too quickly for me to have an authoritative view on what to recommend. Basically, try the context manager, and if things are too slow or break, then compile at a more granular level.
Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop. It’s also critical to compile functions instead of modules if your module doesn’t call the forward() or __call__() functions; for example, HuggingFace models tend to use generate, in which case you want torch.compile(model.generate). I haven’t personally looked at compiling data loading yet; it’s interesting, so if you find specific issues please let me know.
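For the generate case, a rough sketch of what I mean (the model name is just illustrative and this assumes the HuggingFace transformers library):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()

# The top-level entry point here is generate(), not forward()/__call__(),
# so compile the function that is actually called.
model.generate = torch.compile(model.generate)

inputs = tokenizer("Hello", return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))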
Hi @marksaroufim, thank you for your reply! This basically explains the things I was looking for. I have a few questions regarding question #4, compiling models vs compiling the training step:
Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop.
I think 2.0 is the latest stable version as of now? Is support for optimizers improved in the nightlies?
If I compiled the model rather than the entire training step, would loss.backward() also be JIT compiled? I guess the backward pass is essentially the reverse of the forward pass, so it should benefit from the graph optimization done for the forward pass as well?
@marksaroufim what’s the state of things now? I feel like this should go into the FAQ, which talks about DDP but never makes a suggestion for what to actually do.
(1) Running DDP and torch.compile together by default should work (if not, let us know!). The “default” integration basically takes the graph that we capture, splits it into buckets/regions, and compiles each bucket separately, the idea being that we can issue all-reduces in the backward that overlap with each (compiled) bucket.
The main benefit of this integration is that it’s pretty simple (if you are seeing speedups compared to eager, then great). One potential drawback, though, is that compile won’t be able to fuse things across buckets. In practice, this may or may not actually matter for getting good performance.
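A minimal sketch of what that default integration looks like from the user side (assumes a process group has already been initialized and a GPU assigned to this rank; the model is a placeholder):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed.init_process_group(...) has already run.
model = nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model)

# Compile the DDP-wrapped module: Dynamo splits the captured graph at DDP
# bucket boundaries and compiles each bucket, so backward all-reduces can
# overlap with the compiled buckets.
ddp_model = torch.compile(ddp_model)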
(2) There is an experimental config you can toggle to try to get better performance, if you’re willing to try out some experimental features that are still being hardened: torch._dynamo.config.optimize_ddp = "python_reducer" (link).
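The toggle itself is just the one line quoted above:

import torch._dynamo

torch._dynamo.config.optimize_ddp = "python_reducer"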
That config basically ensures that the all-reduces in the backward show up inside of the compiled graph, so compile can get one giant backward graph to optimize - containing all of the backward ops, and the all-reduce communications. The main reason it’s not on by default is that it requires “compiled autograd” (Ed has a great podcast on it here). Compiled autograd is another feature in compile that effectively allows us to compile more of the backward than we do ordinarily - it can capture stuff like AccumulateGrad nodes in the backward - although it is currently undergoing a re-write and being further hardened. If you’re willing to try it out, you can do so with:
# compile your model
m = torch.compile(DDP(m))

# run the forward/backward under compiled autograd
with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    fw_out = m(inp)
    loss = loss_fn(fw_out, ...)
    loss.backward()