How to get the whole graph when compile_subgraph is triggered

I am working on whole-graph optimization. I have read the Dynamo source code, and in each of the following places:

  1. generic_jump
  2. break_graph_if_unsupported
  3. step_unsupported
  4. store_attr
  5. return_value

a subgraph is compiled. But I want to know how those graphs link to each other. How can I get this information?

Example code:
import torch
import torch.nn as nn
import torch._dynamo
from torch.utils.checkpoint import CheckpointFunction

class MyModule(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.linear(x)
        x = self.activation(x)
        return x

class ParallelModel(nn.Module):
    def __init__(self, input_dim, hidden_size, output_dim, use_checkpoint=True):
        super().__init__()
        # use_checkpoint is a hypothetical flag that selects between the two branches below
        self.use_checkpoint = use_checkpoint
        self.embed0 = nn.Linear(input_dim, input_dim)
        self.embed1 = nn.Linear(input_dim, input_dim)

        self.block0 = MyModule(input_dim, hidden_size)
        self.block1 = MyModule(input_dim, hidden_size)
        self.linear = nn.Linear(hidden_size * 2, output_dim)

    def forward(self, x):
        x0 = self.embed0(x)
        x1 = self.embed1(x)
        if self.use_checkpoint:
            # CheckpointFunction.apply(run_function, preserve_rng_state, *args)
            x0 = CheckpointFunction.apply(self.block0, False, x0)
            x1 = CheckpointFunction.apply(self.block1, False, x1)
        else:
            x0 = self.block0(x0)
            x1 = self.block1(x1)
        x = torch.concat([x0, x1], dim=-1)
        out = self.linear(x)
        return out

if __name__ == "__main__":
    device = "cpu"  # torch.cuda.current_device()
    mod = ParallelModel(10, 12, 20)
    x = torch.randn(16, 10, device=device)
    # PyTorch 2.0 signature; later releases use torch._dynamo.explain(mod)(x)
    (
        explanation,
        out_guards,
        graphs,
        ops_per_graph,
        break_reasons,
        explanation_verbose,
    ) = torch._dynamo.explain(mod, x)
    print(f"there are {len(graphs)} graphs")

It sounds like you probably want torch.compile(fullgraph=True), which turns any graph break into a hard error, so you either get a single graph out or find out exactly where the break is. Alternatively, take a look at torch._dynamo.export(), although it is still very early days.
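A small sketch of what fullgraph=True does, assuming a recent PyTorch 2.x. The functions no_break and with_break are hypothetical examples, and backend="eager" is used only so nothing is actually code-generated: the function without a break compiles into one graph, while the one with a forced break raises instead of being silently split.

```python
import torch
import torch._dynamo


def no_break(x):
    return torch.relu(x) + 1


def with_break(x):
    x = x * 2
    torch._dynamo.graph_break()  # explicitly force a graph break
    return x + 1


x = torch.randn(4)

# Succeeds: the whole function fits in a single graph.
out = torch.compile(no_break, backend="eager", fullgraph=True)(x)

# Fails loudly instead of producing two subgraphs.
try:
    torch.compile(with_break, backend="eager", fullgraph=True)(x)
    raised = False
except Exception as err:  # typically torch._dynamo.exc.Unsupported
    raised = True
    print(type(err).__name__)
```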

Regardless, if you want to visually inspect the multiple subgraphs, you can run:

import torch
import torch.nn as nn
import torch._dynamo

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)

    def forward(self, x):
        x = self.fc1(x)
        print("IM A GRAPH BREAK")  # Dynamo cannot put print() in the graph, so it splits here
        x = torch.nn.functional.gelu(x)
        return x

model = MLP()

batch_size = 8
inp = torch.randn(batch_size, 32)

def toy_backend(gm, sample_inputs):
    # Called once per captured subgraph: print it, then run it unchanged
    gm.print_readable()
    return gm.forward

fn = torch.compile(backend=toy_backend, dynamic=True)(model)

# triggers compilation on the first run (one backend call per subgraph)
out = fn(inp)

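This prints the first and second graph one after the other, and that order is how you should stitch them together: the first graph runs, then the code that caused the break (here the print) runs eagerly in Python, and then the second graph runs on the values the first one produced.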
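If you want to keep the subgraphs rather than just print them, one option is a backend that records every GraphModule it receives in a list. This is a minimal sketch assuming a recent PyTorch 2.x; recording_backend and captured are hypothetical names, and torch._dynamo.graph_break() stands in for the print() break used above. Walking each graph's placeholder and output nodes shows what flows into and out of each piece.

```python
from typing import List

import torch
import torch._dynamo
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

captured: List[torch.fx.GraphModule] = []


def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend once per subgraph, in the order it compiles them.
    captured.append(gm)
    return gm.forward  # run the captured graph unchanged


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)

    def forward(self, x):
        x = self.fc1(x)
        torch._dynamo.graph_break()  # stand-in for the print() graph break
        return F.gelu(x)


torch._dynamo.reset()
model = MLP()
x = torch.randn(8, 32)
out = torch.compile(model, backend=recording_backend)(x)

# Inputs and outputs of each subgraph show how values pass from one to the next.
for i, gm in enumerate(captured):
    inputs = [n.target for n in gm.graph.nodes if n.op == "placeholder"]
    outputs = [n.args[0] for n in gm.graph.nodes if n.op == "output"]
    print(f"graph {i}: inputs={inputs} outputs={outputs}")
```

Note that this only recovers the pieces and their order; the glue between them is the eager Python that Dynamo left outside the graphs, so no single fx graph describes the whole model.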

I am also working on getting the whole graph. Any updates? @zjjott Thanks.