Can torch.cond be used recursively?

I am wondering if torch.cond can be used recursively to define something a bit more complicated than the simple if/else function call shown in the docs. Are there any red flags for implementing arbitrary logic with it?

I have a recursive call that seemed to work fine while I was developing it, but then it started ignoring true conditions and following the wrong branch (unless I have a bug somewhere). I can put together a repro if needed, but first I thought it would be good to ask whether recursion is a no-no.

Thanks

If you’re able to create a small repro and file an issue on GitHub, that would be helpful! I wouldn’t be surprised if this doesn’t work today: we handle cond in dynamo by statically tracing out subgraphs for the true and false branches, and I’m not sure off the top of my head how that tracing step will handle the case where we recursively call back into the same function that invokes the cond.
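
For reference, the basic non-recursive torch.cond pattern looks roughly like this (a sketch based on the documented torch.cond(pred, true_fn, false_fn, operands) signature; both branches are callables that get traced into separate subgraphs when compiled):

import torch

def true_fn(x):
    return x + 1

def false_fn(x):
    return x - 1

def f(x):
    # pred is a boolean scalar tensor; dynamo traces true_fn and
    # false_fn into their own subgraphs rather than picking one
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

compiled_f = torch.compile(f)
print(compiled_f(torch.randn(3)))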

Ok, thanks. I will try to make a small repro soon. For now, it looked to me like the condition was already decided by an earlier recursion step, so that with something like this…

import torch

x = torch.tensor(0)

def func(x):
  if x < 2:
    return func(x + 1)
  return foo(x)  # foo: some other function, not shown here

func(x)

x was always evaluated as < 2 because that was true on the first call, and even after the value reached 2 (so x < 2 was no longer true), it still executed the if block
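
For comparison, in plain eager Python (no compilation, with foo just a hypothetical stand-in) the recursion bottoms out as expected once x reaches 2:

import torch

def foo(x):
  # stand-in for the real foo, only here to make the snippet runnable
  return x

def func(x):
  if x < 2:
    return func(x + 1)
  return foo(x)

print(func(torch.tensor(0)))  # tensor(2): the if block stops firing once x < 2 is False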