Graph breaks or No graph breaks in torchdynamo

I have read some introductions to TorchDynamo. It can emit either multiple sub-graphs (when there are graph breaks) or a single graph without any breaks.

I am curious about why it still produces multiple sub-graphs if it can generate the entire graph. What would be the sacrifice if we choose not to have any graph breaks?

Is it possible to explain it in more detail using the following example?


def func(x):
    if x.shape[0] > 10:
        return x.cos()
    return x.sin()

Many thanks.

Graph breaks mean a performance hit, so you typically want as few of them as possible. In your example, a single graph cannot cover both branches, because the control flow depends on the shape of a tensor.

You can work around this if your compiler supports dynamic shapes, which torch.compile(..., dynamic=True) lets you enable. However, the speedups you get will be less drastic than if you make your model not dynamic.
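As a rough sketch of what that looks like with the function from the question: the only things added here are the backend="eager" argument (an assumption, picked so the example runs without a C++ toolchain; the default backend is what you would use for real speedups) and the two sample input sizes.

```python
import torch


def func(x):
    # Python-level branch on a tensor dimension, as in the question
    if x.shape[0] > 10:
        return x.cos()
    return x.sin()


# dynamic=True asks the compiler to treat shapes symbolically where it can.
# backend="eager" is an assumption made for this sketch only: it runs the
# captured graph without code generation, so no speedup should be expected.
compiled = torch.compile(func, backend="eager", dynamic=True)

small = torch.randn(4)    # takes the sin() branch
large = torch.randn(16)   # takes the cos() branch

out_small = compiled(small)
out_large = compiled(large)
```

Both calls should return the same values as calling func directly; what changes between configurations is how often recompilation is triggered when the input shape changes.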

Also a great tool to understand graph breaks in your model is torch._dynamo.explain()

Where can one find documentation for this function?

The code XD https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L708

Otherwise, the new torch logging system is similar and better documented (see torch._logging in the PyTorch 2.1 documentation), and it has a graph break explainer.
