Workaround for the unsupported return of multiple tensors from torch.cond in a model intended for torch.onnx.export(dynamo=True, ...)

I am trying to export, with torch.onnx.export, an nn.Module that has a conditional in its computational graph. In essence, it is similar to this example:

import torch

class Wrapper(torch.nn.Module):
    """Thin container whose ``forward`` delegates to an owned ``CondModel``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigning an nn.Module attribute registers it as a submodule.
        self.cond_model = CondModel()

    def forward(self, x):
        """Return exactly what ``self.cond_model`` produces for ``x``."""
        return self.cond_model(x)

class CondModel(torch.nn.Module):
    def forward(self, x):
        
        def true_fn(x,z):
            x = x + 1.0
            z = z * 0.0
            return x,z

        def false_fn(x,z):
            x = x - 1.0
            z = z * 1.0
            return x,z

        z = torch.rand(x.shape)
        nt = torch.cond(x.sum() > 0, true_fn, false_fn, [x,z])
        return nt

As per the documentation, torch.cond must return a single tensor. Is there a (possibly hacky) workaround that allows returning multiple tensors from it?
I tried using nested tensors:

def true_fn(x,z):
   x = x + 1.0
   z = z * 0.0
   nt = torch.nested.nested_tensor([x,z], layout=torch.jagged)
   return nt

But compilation fails when validating the .shape of the returned tensors (for NestedTensors, .shape loses its precise meaning):

torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape: torch.Size([2, s1]) vs torch.Size([2, s2])', 'stride: (s1, 1) vs (s2, 1)', where lhs is TensorMetadata(shape=torch.Size([2, s1]), dtype=torch.float32, requires_grad=False, stride=(s1, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, s2]), dtype=torch.float32, requires_grad=False, stride=(s2, 1), memory_format=None, is_quantized=False, qparams={})
**Full traceback here**
Traceback (most recent call last):
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 55, in graph_break_as_hard_error
  return fn(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 906, in call_function
  unimplemented(
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/exc.py", line 356, in unimplemented
  raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape: torch.Size([2, s1]) vs torch.Size([2, s2])', 'stride: (s1, 1) vs (s2, 1)', where lhs is TensorMetadata(shape=torch.Size([2, s1]), dtype=torch.float32, requires_grad=False, stride=(s1, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, s2]), dtype=torch.float32, requires_grad=False, stride=(s2, 1), memory_format=None, is_quantized=False, qparams={})

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/iony/DTU/f24/thesis/code/early_exit_vit/simple/example_conditional.py", line 34, in <module>
  result = model(input_tensor)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
  return forward_call(*args, **kwargs)
File "/home/iony/DTU/f24/thesis/code/early_exit_vit/simple/example_conditional.py", line 9, in forward
  nt = self.cond_model(x)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
  return forward_call(*args, **kwargs)
File "/home/iony/DTU/f24/thesis/code/early_exit_vit/simple/example_conditional.py", line 28, in forward
  nt = torch.cond(x.sum() > 0, true_fn, false_fn, [x,z])
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_higher_order_ops/cond.py", line 201, in cond
  return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 576, in _fn
  return fn(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1406, in __call__
  return self._torchdynamo_orig_callable(
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 566, in __call__
  return _compile(
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 1006, in _compile
  guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 734, in compile_inner
  return _compile_inner(code, one_graph, hooks, transform)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
  return function(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 769, in _compile_inner
  out_code = transform_code_object(code, transform)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1402, in transform_code_object
  transformations(instructions, code_options)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 237, in _fn
  return fn(*args, **kwargs)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 681, in transform
  tracer.run()
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2906, in run
  super().run()
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1076, in run
  while self.step():
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 986, in step
  self.dispatch_table[inst.opcode](self, inst)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 683, in wrapper
  return inner_fn(self, inst)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1763, in CALL_FUNCTION_EX
  self.call_function(fn, argsvars.items, kwargsvars)
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 921, in call_function
  self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 58, in graph_break_as_hard_error
  raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.

from user code:
 File "/home/iony/miniconda3/envs/eevit/lib/python3.9/site-packages/torch/_higher_order_ops/cond.py", line 193, in _cond_op_wrapper
  return cond_op(*args, **kwargs)

Is this feature not implemented yet, even in a nightly build? Or is there another workaround that might work if I only intend to run inference?

Error logs

No response

Versions

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.19.2
[pip3] onnxscript==0.1.0.dev20241226
[pip3] torch==2.6.0.dev20241226+cu124

Solved here