Does Dynamo not support nested if statements?

I was wondering what happens when I compile a nested if statement.

I actually tried it, and it does not seem to work the way I expected:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4, 3)
        self.lin2 = nn.Linear(3, 3)
        self.lin3 = nn.Linear(3, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.lin1(x)
        sumval = torch.sum(x)
        if sumval > 2.0:
            out = self.relu(x)
            againsumval = torch.sum(x)
            if againsumval > 3.0:
                return self.lin2(x)
            else:
                return self.lin3(x)
        else:
            out = x
        return out

# nonefunc: a minimal no-op backend that hands the captured graph back unchanged
def nonefunc(gm, example_inputs):
    return gm.forward

model = Model()
new_model = torch.compile(backend=nonefunc)(model.forward)
new_model(torch.zeros((1, 4)))

I decompiled the transformed code with depyf, and here is the result:

def __resume_at_96_2(self, x):
    out = self.relu(x)
    againsumval = torch.sum(x)
    if againsumval > 3.0:
        return self.lin2(x)
    return self.lin3(x)
    out = x
    return out

def __resume_at_274_3(x):
    out = x
    return out

def __transformed_code_0_for_forward(self, x, num):
    againsumval = None; out = None; sumval = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    __temp_2, __temp_3 = __compiled_fn_1(self._modules['lin1']._parameters[
        'weight'], self._modules['lin1']._parameters['bias'], x)
    x = __temp_2
    if __temp_3:
        return __resume_at_96_2(self, x)
    return __resume_at_274_3(x)
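
For reference, the dump above came from running the compiled function under depyf (a sketch assuming depyf's prepare_debug API; the exact interface may vary between versions):

import depyf

# write the decompiled source of the transformed and resume functions to ./dump_src
with depyf.prepare_debug("./dump_src"):
    new_model(torch.zeros((1, 4)))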

__compiled_fn_1 takes over the code that produces the condition variable required for the first “if” expression: it returns both the updated x (__temp_2) and the branch condition (__temp_3).
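
Conceptually, judging from the call site in the decompiled output, the compiled graph computes something like the following (a hypothetical sketch, not the actual generated code):

import torch
import torch.nn.functional as F

# sketch of what __compiled_fn_1 appears to compute, inferred from its
# arguments (lin1's weight and bias, plus the input x) and its two outputs
def compiled_fn_1_sketch(weight, bias, x):
    x = F.linear(x, weight, bias)   # self.lin1(x)
    cond = torch.sum(x) > 2.0       # sumval > 2.0
    return x, cond                  # (__temp_2, __temp_3)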

I thought TorchDynamo divides code into a Python part and a non-Python part that can be compiled by a backend such as Triton or some other custom one. But it seems that the “resume function” is not divided further, so can a backend compiler work on that code?

The “resume” function is created in “create_call_resume_at”, and the interpreter does not trace inside it. That is why “jump_graph_break” is not called for the resume code, I guess.

I am curious whether there is any way to divide the “resume” function as well.

torch.compile generally supports nested conditionals. The problem in your code is that the if/else branches on a data-dependent condition (if torch.sum(x) > 2.0). By default, Dynamo graph-breaks on data-dependent control flow, giving you separate graphs for the regions before and after the break.
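
You can see the breaks directly with torch._dynamo.explain (a sketch assuming the PyTorch 2.1+ calling convention; older releases pass the example inputs differently):

import torch
import torch._dynamo as dynamo

model = Model()
# explain() runs Dynamo without a real backend and reports every captured
# graph, every graph break, and the reason for each break
explanation = dynamo.explain(model.forward)(torch.zeros((1, 4)))
print(explanation)  # includes graph_count, graph_break_count, break_reasons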

Thanks for the answer.

I know that if there is a data-dependent condition, Dynamo breaks the graph. That is where the “jump_graph_break” function is called, and it in turn calls “compile_subgraph” and “create_call_resume_at”. AFAIK, what “create_call_resume_at” does is extract all the instructions after the break and turn them into a code object (plus add a call instruction to the main output instructions).

I would like to know whether I can recursively turn the “resume” function into subgraphs.
In the example I posted, I expect __resume_at_96_2 to be divided again into before-if and after-if parts, just like __transformed_code_0_for_forward.
I think it is not impossible, since we could (maybe?) call the callback function again with the created “resume” function's code object.

I think it would be much better if the compiler did not give up on the region after the break.

I found out that it was because of the input I fed in while compiling.

Compilation is essentially based on tracing, so a resume function has to actually be called in order to be compiled recursively. But with my input, one branch body was never entered because of the branch condition.

I confirmed that nested if statements are also compiled once there is input data that actually enters the branch.
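
For reference, here is a minimal sketch that shows this (the constant weight initialization is my own addition to make the branch conditions deterministic; the counting backend just records the graphs Dynamo hands it):

import torch

graphs = []

def counting_backend(gm, example_inputs):
    graphs.append(gm)  # record every FX graph Dynamo captures
    return gm.forward

model = Model()
# force deterministic branch conditions so the example is reproducible
with torch.no_grad():
    for lin in (model.lin1, model.lin2, model.lin3):
        lin.weight.fill_(1.0)
        lin.bias.fill_(0.0)

compiled = torch.compile(backend=counting_backend)(model.forward)

compiled(torch.zeros((1, 4)))  # sumval == 0.0, so only the else branch runs
n_before = len(graphs)

compiled(torch.ones((1, 4)))   # sumval == 12.0: the nested if is entered,
                               # so the resume function is traced and
                               # compiled too, adding more graphs
print(n_before, len(graphs))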