I am working on whole-graph optimization. I have read the dynamo source code, and found several places where a graph break happens:
- generic_jump
- break_graph_if_unsupported
- step_unsupported
- store_attr
- return_value
At each of these points a subgraph is compiled, but I want to know how those subgraphs link to each other. How can I get this information?
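So far the only way I have found to observe the individual subgraphs is to register a custom backend that records every GraphModule dynamo hands it. A minimal sketch, assuming the standard PyTorch 2.x custom-backend contract (recording_backend is my own name):

import torch

compiled_subgraphs = []

def recording_backend(gm: torch.fx.GraphModule, example_inputs):
    # dynamo calls this once per compiled subgraph, in compilation order
    compiled_subgraphs.append(gm)
    print(f"subgraph {len(compiled_subgraphs)}: {len(list(gm.graph.nodes))} nodes")
    return gm.forward  # run the captured subgraph unoptimized

# usage: opt_mod = torch.compile(mod, backend=recording_backend)

But this only shows the subgraphs in isolation, not how control flow stitches them together.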
example code:
import torch
import torch.nn as nn
from torch.utils.checkpoint import CheckpointFunction


class MyModule(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.linear(x)
        x = self.activation(x)
        return x


class ParallelModel(nn.Module):
    def __init__(self, input_dim, hidden_size, output_dim):
        super().__init__()
        self.embed0 = nn.Linear(input_dim, input_dim)
        self.embed1 = nn.Linear(input_dim, input_dim)
        self.block0 = MyModule(input_dim, hidden_size)
        self.block1 = MyModule(input_dim, hidden_size)
        self.linear = nn.Linear(hidden_size * 2, output_dim)

    def forward(self, x):
        x0 = self.embed0(x)
        x1 = self.embed1(x)
        if self.training:
            # the checkpointed path is what causes graph breaks under dynamo
            x0 = CheckpointFunction.apply(self.block0, False, x0)
            x1 = CheckpointFunction.apply(self.block1, False, x1)
        else:
            x0 = self.block0(x0)
            x1 = self.block1(x1)
        x = torch.concat([x0, x1], dim=-1)
        out = self.linear(x)
        return out


if __name__ == "__main__":
    device = "cpu"  # torch.cuda.current_device()
    mod = ParallelModel(10, 12, 20)
    x = torch.randn(16, 10, device=device)
    (explanation,
     out_guards,
     graphs,
     ops_per_graph,
     break_reasons,
     explanation_verbose) = torch._dynamo.explain(mod, x)
    print(f"there are {len(graphs)} graphs")