Enable torch.compile only for backward

Hi,

In our use case, I’d like to optimize only the backward graph and leave the forward graph running in eager mode. Is it possible to add an option to torch.compile, like this?

torch._dynamo.config.compiled_autograd = True
@torch.compile(bwd_only=True, ...)
class MyModule(nn.Module):
    ....

We actually have an option for compiling only the backward graph - compiled autograd. It’s still experimental, but you can try it out with something like:

# run the fw in eager (or with compile)
out = model(inp)
loss = loss_fn(out, ...)

# take the entire backward graph, and compile it
with torch._dynamo.utils.maybe_enable_compiled_autograd(True, fullgraph=True, dynamic=False):
    loss.backward()

Ed also has a podcast episode on it if you want to learn more: https://pytorch-dev-podcast.simplecast.com/episodes/compiled-autograd-TCcEyBRZ

@bdhirsh, thanks for your reply!

A follow-up question: do you know how to compile the backward of a single nn.Module, rather than the entire backward graph?

torch.compiling that individual module will be easiest, although it will result in your forward also getting compiled.
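For reference, here is a minimal sketch of what that looks like (the model and shapes here are made up): compiling a single submodule in place leaves the rest of the model in eager mode, but both the forward and the backward of that submodule get compiled.

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(32, 64)   # stays eager
        self.block = nn.Sequential(      # the region we want compiled
            nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)
        )
        self.head = nn.Linear(64, 10)    # stays eager

    def forward(self, x):
        return self.head(self.block(self.embed(x)))

model = MyModel()
# Compile only the inner block; everything around it keeps running eagerly.
model.block = torch.compile(model.block)

loss = model(torch.randn(8, 32)).sum()
loss.backward()   # the block's backward is compiled as well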

Out of curiosity - what use case do you have that requires wanting to keep everything in eager mode, except a specific region of the backward?

It’s for LLM training. In our code base, torch.compile fails on an individual Transformer module because our forward logic is too complex. What we really want to optimize is the backward of the Transformer block, since there are plenty of opportunities to overlap communication and computation in the backward pass. We are considering implementing this automatically, and I believe capturing only the backward graph will be easier than capturing the forward graph.

Got it. I can’t give much more specific advice without more detail, but a few things you might find useful are:

Custom ops: if there is a region of your forward that is already well optimized (too complicated to naively capture with torch.compile, and compile won’t be able to speed it up much anyway), you can wrap it in a custom operator. torch.compile will compile everything around it, but treat it as a black box that gets invoked at runtime. Python Custom Operators — PyTorch Tutorials 2.4.0+cu121 documentation
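As a rough illustration, here is a minimal sketch using the torch.library.custom_op API from that tutorial (the op name and body are made up): torch.compile traces everything around the custom op, but the op itself is treated as opaque and just invoked at runtime.

import torch

# Hypothetical op standing in for an already-optimized region of the forward.
@torch.library.custom_op("mylib::fused_region", mutates_args=())
def fused_region(x: torch.Tensor) -> torch.Tensor:
    # torch.compile treats this body as a black box and calls it at runtime.
    return torch.relu(x) * 2

@fused_region.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation so the surrounding graph can still be compiled.
    return torch.empty_like(x)

@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    y = x + 1               # compiled
    return fused_region(y)  # opaque custom op inside the compiled region

print(fn(torch.randn(4)))
# If you need gradients through the custom op, you would also register a
# backward formula (e.g. via fused_region.register_autograd).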

Functional collectives: compile has some support for compiling comms operators, if you want to let the compiler try optimizing them. [RFC] PT2-Friendly Traceable, Functional Collective Communication APIs · Issue #93173 · pytorch/pytorch · GitHub
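Here is a minimal single-process sketch (the function and setup are made up, and it assumes the torch.distributed._functional_collectives module from that RFC): unlike the in-place c10d calls, these collectives return new tensors, which is what lets torch.compile trace and reorder them.

import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Single-process gloo group so the sketch runs on one machine.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    # Functional all_reduce returns a new tensor instead of mutating y,
    # so it shows up as an ordinary op in the traced graph.
    return funcol.all_reduce(y, "sum", group=dist.group.WORLD)

print(fn(torch.ones(4)))
dist.destroy_process_group()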

Sometimes even Python custom operators cannot solve the problem. A good example is Megatron-LM/megatron/legacy/model/transformer.py at main · NVIDIA/Megatron-LM · GitHub, where I tried to wrap the following code in a Python custom op:

with tensor_parallel.get_cuda_rng_tracker().fork():
    attention_probs = self.attention_dropout(attention_probs)

Wrapped this way, it can be successfully captured by Dynamo, but I got errors when running the model. It seems that Dynamo cannot handle the random state well.

In the backward pass, however, the RNG code will not appear in the autograd backward graph; only the backward of the dropout will. That’s why I believe capturing only the backward is easier than capturing both the forward and backward graphs.
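To make that concrete, here is a minimal sketch (a made-up module, not the Megatron code) tying it back to compiled autograd: the forward, including any RNG bookkeeping and the dropout sampling, runs in eager mode, while the captured backward graph only contains the dropout's multiply-by-saved-mask.

import torch
import torch._dynamo.utils
import torch.nn as nn

class TinyAttentionLike(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.attention_dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        # Any RNG state handling (e.g. forking a CUDA RNG tracker) stays in eager.
        probs = torch.softmax(self.proj(x), dim=-1)
        return self.attention_dropout(probs)

model = TinyAttentionLike()
x = torch.randn(4, 16, requires_grad=True)

# Forward in eager; backward captured and compiled by compiled autograd,
# using the same API quoted earlier in this thread.
out = model(x)
loss = out.sum()
with torch._dynamo.utils.maybe_enable_compiled_autograd(True, fullgraph=True, dynamic=False):
    loss.backward()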