Let HOFs (torch.cond, torch.while_loop) skip compilation of their input functions when needed

Hi all! Great work on the HOFs 🙂

For debugging, it would be nice to have a switch I could flip to make

torch.while_loop(
    cond_fn,
    body_fn,
    carried_inputs,
)

behave exactly like its plain Python equivalent:

val = carried_inputs
while cond_fn(*val):
    val = body_fn(*val)
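The same idea applies to torch.cond, the other HOF in the title. A minimal sketch of its eager counterpart, assuming the `torch.cond(pred, true_fn, false_fn, operands)` calling convention; `eager_cond` is a hypothetical helper name, not a PyTorch API:

```python
def eager_cond(pred, true_fn, false_fn, operands):
    """Hypothetical plain-Python stand-in for torch.cond.

    Exactly one branch is called, eagerly, so breakpoints inside it are hit.
    """
    # bool() accepts Python bools as well as one-element boolean tensors
    if bool(pred):
        return true_fn(*operands)
    return false_fn(*operands)
```

Unlike the compiled HOF, this only ever runs the branch that is taken, which is exactly what makes it convenient to step through in a debugger.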

I know the compiled version behaves the same, but I’d like to skip compiling cond_fn and body_fn because:

  • I could then set breakpoints inside body_fn and inspect intermediate values.
  • body_fn takes a long time to compile, and in unit tests I’d prefer the eager version, which is faster when it only runs once.

Is there a way to switch between these behaviors? Currently I’m patching them in the unit tests, but I wonder if there’s a PyTorch-native way to do so.
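Until there is a native switch, one alternative to patching is to route calls through a small project-level wrapper. This is a sketch under assumptions: the `HOF_EAGER` environment variable and the `while_loop` wrapper are hypothetical names chosen for illustration, and the compiled path simply forwards to `torch.while_loop` as named above.

```python
import os


def while_loop(cond_fn, body_fn, carried_inputs):
    """Hypothetical wrapper: eager Python loop if HOF_EAGER=1, else torch.while_loop."""
    if os.environ.get("HOF_EAGER") == "1":
        # Eager path: a plain Python loop, so breakpoints in body_fn are hit
        # and nothing is compiled.
        val = carried_inputs
        while cond_fn(*val):
            val = body_fn(*val)
        return val
    # Imported lazily so the eager path does not need torch at all.
    import torch
    return torch.while_loop(cond_fn, body_fn, carried_inputs)
```

Because the variable is read at call time rather than at import time, a test can flip it per test (for example with pytest's `monkeypatch.setenv`) without reloading any modules.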

Cheers,

Lucas