I am looking for clarification on the best point to wrap a model in torch.compile, relative to other wrappers / calls. Some previous answers stated it was subject to change in newer releases, so I am looking for information about the current state of things.
Here’s a summary of what I’ve seen so far:
- DDP

  According to Distributed Data Parallel — PyTorch main documentation, the DDP wrapper should definitely be applied first, as a requirement for DDPOptimizer to work properly.
  In contrast, according to the Torch.compile() before or after .cuda() answer by @marksaroufim, compile() should be called first, on the inner module. This is reiterated in How should I use torch.compile properly? - #2 by marksaroufim.
  Which of these is correct, and how big is the impact?
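For concreteness, here is how I understand the two orderings. This is a single-process CPU sketch (gloo backend, made-up shapes, and `backend="eager"` purely so it runs anywhere), not a claim about which one is right:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process CPU setup with the gloo backend so the sketch runs anywhere;
# the address/port values are arbitrary.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# Ordering (a): wrap in DDP first, then compile -- the ordering the DDP notes
# describe, which lets Dynamo's DDPOptimizer see DDP's gradient buckets.
model_a = torch.nn.Linear(8, 4)
ddp_then_compile = torch.compile(DDP(model_a), backend="eager")

# Ordering (b): compile the inner module first, then wrap it in DDP.
model_b = torch.nn.Linear(8, 4)
compile_then_ddp = DDP(torch.compile(model_b, backend="eager"))

x = torch.randn(2, 8)
out_a = ddp_then_compile(x)
out_b = compile_then_ddp(x)

dist.destroy_process_group()
```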
- .cuda()

  According to the answers by Mark linked above, .cuda() should be called first, i.e. before compile(cuda_model). Is the impact noticeable? The second thread states it shouldn't be a huge deal, but I'd like to make sure.

- AMP

  I can see two main ways:
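As a sketch of that ordering (device guard and `backend="eager"` are my additions so it also runs on CPU-only machines):

```python
import torch

# Move the model to its target device first, then compile, so compilation
# traces tensors that are already on their final device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 4).to(device)
compiled = torch.compile(model, backend="eager")  # eager backend just for portability

x = torch.randn(2, 8, device=device)
out = compiled(x)
```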
```python
@torch.compile
def uses_autocast(model, input):
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
```
and
```python
model = torch.compile(model)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
```
Per Interaction of torch.no_grad and torch.autocast context managers with torch.compile · Issue #100241 · pytorch/pytorch · GitHub, the second one seems to be recommended, since entering/exiting the context manager inside the compiled region causes graph breaks.
Is this still valid for new versions? Are any changes expected here?
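For reference, here is the second pattern as a self-contained sketch; the device/dtype selection is my assumption so it also runs on CPU (where autocast uses bfloat16), and `backend="eager"` is only for portability:

```python
import torch

# Keep autocast *outside* the compiled call, so entering/exiting the context
# manager cannot break the graph inside the compiled region.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 4).to(device)
compiled = torch.compile(model, backend="eager")

x = torch.randn(2, 8, device=device)
with torch.autocast(device_type=device, dtype=amp_dtype):
    out = compiled(x)
```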
- General extent of compilation
The two scenarios here are
```python
model = torch.compile(model)
for input, target in data:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```
vs
```python
@torch.compile
def train_step(...):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

for input, target in data:
    train_step(...)
```
As the issue linked above mentions, in theory the bigger the extent of compilation the better, but real results may vary.
In my experiments, compiling the whole train-step function sometimes yields improved performance.
However, the docs and tutorials mainly cover compiling the module, or a standalone function (not a full training step).
Is compiling the whole step recommended / supported? Is there official guidance on the best extent of compilation? (In my experience, including data loading gets messy.)
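For completeness, here is the kind of whole-step compilation I mean, as a minimal runnable version of the second pattern above. The model, loss, shapes, and hyperparameters are made up, and `backend="eager"` is used only so the sketch runs without a GPU or C++ toolchain:

```python
import torch

# Everything -- zero_grad, forward, loss, backward, optimizer step -- lives
# inside one compiled function.
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

@torch.compile(backend="eager")
def train_step(inp, target):
    optimizer.zero_grad()
    output = model(inp)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    return loss

data = [(torch.randn(2, 8), torch.randn(2, 4)) for _ in range(3)]
losses = [train_step(inp, tgt).item() for inp, tgt in data]
```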
Thanks in advance!