Why does the user have to specify a cond function for conditional control flow in PyTorch 2.0?

Recently I read the torch.export tutorial, and I am wondering: if PyTorch 2.0 already uses the frame evaluation (frame capture) feature of CPython, why do we still need special functions for exporting code (like from functorch.experimental.control_flow import cond)?
Wouldn't that be the same problem we had with torch.jit.script?
We would be duplicating code and decreasing readability!

Why can't torch.export do this by itself using the frame evaluation feature of CPython?
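
For reference, this is roughly the pattern the tutorial asks for (a minimal sketch on my side; the module, branch functions, and input shapes are made up, and I'm assuming a PyTorch version where torch.export.export is available):

```python
import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

class M(torch.nn.Module):
    def forward(self, x):
        # A plain `if x.sum() > 0:` here would be data-dependent control
        # flow, which a single trace cannot represent; cond instead
        # records both branches in the exported graph.
        return cond(x.sum() > 0, true_fn, false_fn, [x])

exported = torch.export.export(M(), (torch.randn(3),))
```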


I’m not sure what the rationale is in PyTorch, but TensorFlow went this route and it’s not as clean as one could expect. Whenever TF encounters a conditional, it checks whether the condition is a tensor or a Python value and treats it as either a variable or a “compilation constant”, respectively. For tensors it generates cond blocks in the DAG, like torch.jit.script does, but for constants it just traces the selected branch, like torch.jit.trace does.
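
To illustrate (my own sketch, not from any official TF doc; the function names and inputs are made up):

```python
import tensorflow as tf

@tf.function
def f(x, use_relu):
    # `use_relu` as a Python bool is a "compilation constant":
    # only the chosen branch ends up in the traced graph.
    if use_relu:
        return tf.nn.relu(x)
    return x

@tf.function
def g(x, flag):
    # `flag` as a tensor: AutoGraph rewrites the `if` into a tf.cond
    # node, so both branches are present in the graph.
    if flag:
        return tf.nn.relu(x)
    return x

f(tf.constant([-1.0, 2.0]), True)               # branch baked in at trace time
g(tf.constant([-1.0, 2.0]), tf.constant(True))  # graph-level cond
```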

It seems simple enough when you have just one function, but once you consider a whole program, where exported functions call other exported functions, it becomes quite difficult to track which behavior would be triggered in each situation, and some corner cases get really tricky. In those cases, you might wish you had a mechanism to force a function to be either traced or scripted.
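
For example (again a made-up sketch): the same inner function can end up specialized or wrapped in a graph-level cond depending on what the caller happens to pass, and you can't tell which from the call site:

```python
import tensorflow as tf

def inner(x, training):
    # If `training` arrives as a Python bool, the branch is frozen at
    # trace time; if it arrives as a tensor, AutoGraph emits a tf.cond.
    if training:
        return tf.nn.dropout(x, rate=0.5)
    return x

@tf.function
def outer(x, training):
    # Nothing here tells you which of the two behaviors `inner` gets.
    return inner(x, training)

outer(tf.ones([2, 2]), True)                # separate graph traced per bool value
outer(tf.ones([2, 2]), tf.constant(False))  # single graph with a tf.cond node
```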

tl;dr - it’s probably a design decision.

I think it has a much simpler explanation: torch.compile does not fully support all the corner cases, and this particular corner case just hasn't been handled.

Maybe, but it doesn’t look like a corner case to me. Conditional statements are rather basic constructs.