In our use cases, I’d like to optimize only the backward graph and leave the forward graph running in eager mode. Is it possible to add an option to torch.compile like this?
torch._dynamo.config.compiled_autograd = True
@torch.compile(bwd_only=True, ...)
class MyModule(nn.Module):
    ...
We actually have an option for compiling only the backward graph - compiled autograd. It’s still experimental, but you can try it out with something like:
# run the fw in eager (or with compile)
out = model(inp)
loss = loss_fn(out, ...)
# take the entire backward graph, and compile it
with torch._dynamo.utils.maybe_enable_compiled_autograd(True, fullgraph=True, dynamic=False):
    loss.backward()
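For context, a self-contained sketch of that pattern in a full training step could look like the following; the model, data, optimizer, and loss are placeholders, and the context manager is the same one shown above:

import torch
import torch.nn as nn

# Placeholder model and batch; substitute your own module and data.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
inp = torch.randn(32, 128, device="cuda")
target = torch.randint(0, 10, (32,), device="cuda")

# Forward runs in eager mode.
out = model(inp)
loss = loss_fn(out, target)

# Only the backward pass is captured and compiled by compiled autograd.
with torch._dynamo.utils.maybe_enable_compiled_autograd(True, fullgraph=True, dynamic=False):
    loss.backward()

optimizer.step()
optimizer.zero_grad()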
It’s for LLM training. In our code base, torch.compile fails on an individual Transformer module because our forward logic is too complex. What we really want to optimize is the backward of the Transformer block, since there are plenty of opportunities to overlap communication and computation in the backward pass. We are considering implementing that automatically, and I believe capturing only the backward graph will be easier than capturing the forward graph.
Got it. I can’t give much more specific advice without more detail, but a few things you might find useful are:
Custom ops: if there is a region of your forward that is already well optimized (too complicated to naively capture with torch.compile, and compile won’t be able to speed it up much anyway), you can wrap it in a custom operator. torch.compile will compile everything around it but treat it as a black box that gets invoked at runtime (see the sketch below). Python Custom Operators — PyTorch Tutorials 2.4.0+cu121 documentation
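As a rough illustration, here is a minimal sketch of wrapping an opaque region in a custom operator; the op name mylib::fused_block and its body are placeholders for your own well-optimized forward logic:

import torch

# Hypothetical custom op wrapping a region that dynamo should not trace into.
@torch.library.custom_op("mylib::fused_block", mutates_args=())
def fused_block(x: torch.Tensor) -> torch.Tensor:
    # Arbitrary eager-mode logic; torch.compile treats this call as a black box.
    return torch.nn.functional.gelu(x) * 2.0

# A "fake" (meta) implementation so torch.compile can infer output shapes
# without actually running the op during tracing.
@fused_block.register_fake
def _(x):
    return torch.empty_like(x)

@torch.compile
def f(x):
    # Everything around the custom op is compiled; the op itself runs eagerly.
    return fused_block(x).sum()

Note that for the wrapped region to participate in training you would also register a backward for the op via register_autograd (not shown here).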
with tensor_parallel.get_cuda_rng_tracker().fork():
    attention_probs = self.attention_dropout(attention_probs)
That way it can be successfully captured by dynamo, but we got errors when running the model; it seems dynamo cannot handle the random state well.
In the backward, however, the RNG code does not appear in the autograd backward graph; only the backward of the dropout does. That’s why I believe capturing only the backward graph is easier than capturing both the forward and backward graphs.
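To illustrate that point, here is a small sketch (with placeholder forward code) that walks the autograd graph from the loss: RNG-state bookkeeping done in the forward produces no nodes, while operations with a defined backward, such as dropout, do show up:

import torch
import torch.nn as nn

# Placeholder forward that touches RNG state and applies dropout.
x = torch.randn(4, 8, requires_grad=True)
torch.manual_seed(0)                      # RNG-state bookkeeping: no autograd node
out = nn.functional.dropout(x, p=0.1)     # dropout: has a backward node
loss = out.sum()

# Walk the autograd graph from the loss and print node names.
def walk(fn, depth=0):
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

walk(loss.grad_fn)
# Prints SumBackward0 -> NativeDropoutBackward0 (or similar) -> AccumulateGrad;
# nothing related to the manual_seed call appears in the graph.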