I am looking for clarification on the best point at which to wrap a model in torch.compile, relative to other wrappers / calls. Some previous answers stated this was subject to change in newer releases, so I am looking for information about the current state of things.
Here’s a summary of what I’ve seen so far:
.cuda()
According to the answers by Mark linked above, .cuda() should be called before compiling, i.e. compile(cuda_model). Is the impact noticeable? The second thread states it shouldn’t be a huge deal, but I’d like to make sure.
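For concreteness, the ordering I currently use looks like this (the toy module and shapes are just placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
model = model.cuda()              # move to the GPU first...
model = torch.compile(model)      # ...then compile the CUDA model

x = torch.randn(32, 128, device="cuda")
out = model(x)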
AMP
I can see two main ways:
@torch.compile
def uses_autocast(model, input):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
    return output
and
model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
Is this still valid for new versions? Are any changes expected here?
General extent of compilation
The two scenarios here are
model = torch.compile(model)
for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
vs
@torch.compile
def train_step(...):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

for input, target in data:
    train_step(...)
As the issue linked above mentions, the larger the extent of compilation, the better in theory, but real results may vary.
In my experiments, compiling the whole train-step function sometimes yields better performance.
However, the docs and tutorials mainly refer to compiling the module, or just a standalone function (not a training step).
Is compiling the whole step recommended / supported? Is there official guidance on the best extent of compilation? (In my experience, including data loading gets messy.)
These are good and interesting questions! I’m unaware of a “best practices” guide, but I’m sure @marksaroufim would know more about the proper workflow for using torch.compile.
Your sense is correct: basically, the more things Dynamo has to worry about, the more likely things are to be slower.
Answers to your specific questions:
Distributed: Ideally torch.compile() should optimize communication passes, but because it doesn’t do so quite yet, and because AOT Autograd unrolls both the forward and backward pass, you can’t overlap communication effectively. This was fixed for DDP but not yet for FSDP, for example, so my suggestion in the past is now outdated for DDP but still correct for FSDP.
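To make the ordering concrete, here is a rough sketch of what I mean for the DDP case (this is my usual ordering rather than an official recommendation; the toy model and process-group setup are placeholders assuming a torchrun launch):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_model(local_rank: int) -> nn.Module:
    # assumes torchrun has already set the env vars init_process_group needs
    dist.init_process_group(backend="nccl")
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
    model = model.cuda(local_rank)               # device placement first
    model = DDP(model, device_ids=[local_rank])  # then the DDP wrapper
    return torch.compile(model)                  # then compile the wrapped module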
AMP: You’re right, ideally compile within context managers. Support for this will generally be better in nightlies, and things move too quickly for me to have an authoritative view on what to recommend; basically, try the context manager, and if things are too slow or break, compile at a more granular level.
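As a concrete example of the more granular fallback, something like this, where only the model is compiled and autocast stays outside (fp16 with a GradScaler is just an illustration here, not a prescription):

import torch
import torch.nn as nn

model = torch.compile(nn.Linear(128, 10).cuda())    # compile the model only
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()                # scales grads for fp16

for _ in range(3):
    input = torch.randn(32, 128, device="cuda")
    target = torch.randint(0, 10, (32,), device="cuda")
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()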
Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop. It’s also critical to compile functions instead of modules if your module doesn’t call the forward() or __call__() functions; for example, HuggingFace models tend to use generate, in which case you want torch.compile(model.generate). I haven’t personally looked at compiling data loading yet, it’s interesting, so if you find specific issues please let me know.
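For example, with a HuggingFace model it looks roughly like this (gpt2 is just a stand-in for whatever checkpoint you’re actually using):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()

# generate() doesn't go through forward()/__call__(), so compile the method itself
compiled_generate = torch.compile(model.generate)

inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = compiled_generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))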
Hi @marksaroufim, thank you for your reply! This basically covers what I was looking for. I have a few questions regarding question #4, compiling models vs compiling the training step:
Compiling functions vs modules: When 2.0 was released, support for optimizers wasn’t quite there, but now it is, so yes, you should optimize the training loop.
I think 2.0 is the latest stable version as of now? Is support for optimizers improved in the nightlies?
If I compile the model rather than the entire training step, would loss.backward() also be JIT-compiled? I guess the backward pass is just the reverse of the forward pass, so it should benefit from the graph optimizations done for the forward pass as well?
@marksaroufim, what’s the state of things now? I feel like this should go into the FAQ, which talks about DDP but never makes a suggestion for what to actually do.