I was wondering what happens when I compile a function containing a nested `if` statement. I tried it, and it does not work the way I expected:
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lin1 = nn.Linear(4, 3)
        self.lin2 = nn.Linear(3, 3)
        self.lin3 = nn.Linear(3, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.lin1(x)
        sumval = torch.sum(x)
        if sumval > 2.0:
            out = self.relu(x)
            againsumval = torch.sum(x)
            if againsumval > 3.0:
                return self.lin2(x)
            else:
                return self.lin3(x)
        else:
            out = x
        return out

def nonefunc(gm, example_inputs):
    # no-op backend: just run the captured FX graph eagerly
    return gm.forward

model = Model()
new_model = torch.compile(backend=nonefunc)(model.forward)
new_model(torch.zeros((1, 4)))
```
I decompiled the transformed code with depyf, and here is the result:
```python
def __resume_at_96_2(self, x):
    out = self.relu(x)
    againsumval = torch.sum(x)
    if againsumval > 3.0:
        return self.lin2(x)
    return self.lin3(x)
    out = x
    return out

def __resume_at_274_3(x):
    out = x
    return out

def __transformed_code_0_for_forward(self, x, num):
    againsumval = None; out = None; sumval = None # this line helps Python to generate bytecode with at least the same number of local variables as the original function
    __temp_2, __temp_3 = __compiled_fn_1(self._modules['lin1']._parameters[
        'weight'], self._modules['lin1']._parameters['bias'], x)
    x = __temp_2
    if __temp_3:
        return __resume_at_96_2(self, x)
    return __resume_at_274_3(x)
```
`__compiled_fn_1` contains the code that produces the condition variable required by the first `if` statement.
I thought TorchDynamo divides the code into a Python part and a non-Python part that can be compiled by a backend such as Triton or a custom one. However, the resume functions do not seem to be divided any further, so can a backend compiler still work on that code?
The resume functions are created in `create_call_resume_at`, and the interpreter does not step inside them; I guess that is why `jump_break` is not called for the resume code.
I am curious whether there is any way to divide the resume functions as well.