ONNX export with dynamo using torch.cond for dynamic models

:bug: Describe the bug

Environment
PyTorch version: 2.5.1+cu124

Description
I am trying to implement a dummy example of a model whose forward operations depend on an intermediate calculation on the input. The end goal is to see whether I can export such a model to ONNX.

The model in question is the following:

```python
import torch
import torch.nn as nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class TwoLayerNetDynamic(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(TwoLayerNetDynamic, self).__init__()
        self.model_name = 'TwoLayerNetDynamic'

        self.fully_connected_1 = nn.Linear(input_size, hidden_size)
        self.early_exit_head_1 = nn.Linear(hidden_size, output_size)

        self.threshold = torch.tensor([0.0], dtype=torch.float32).to(DEVICE)
        self.last_exit = nn.Linear(hidden_size, output_size)

        self.training_exits = False

        print("TwoLayerNetDynamic initialized")

    def forward(self, x: torch.Tensor):
        mean = x.mean()

        # version 1: plain Python control flow (scriptable with torch.jit.script)
        # if torch.gt(mean, self.threshold):
        #     x = self.early_exit_head_1(x)
        # else:
        #     x = self.last_exit(x)
        # x = torch.cat([x, mean.reshape_as(x)], dim=1)

        # version 2: torch.cond higher-order op (for torch.export / dynamo)
        x = self.fully_connected_1(x)
        x = torch.cond(mean > 0.0, self.early_exit_head_1, self.last_exit, (x,))
        x = torch.cat([x, mean.reshape_as(x)], dim=1)
        return x
```
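
Before exporting anything, a quick eager-mode sanity check shows both branches running. This is a sketch under my assumptions: the definitions above, plus single-element inputs so that `mean.reshape_as(x)` is shape-compatible.

```python
# Sketch: exercise both cond branches in eager mode.
# Assumes TwoLayerNetDynamic and DEVICE as defined above; single-element
# inputs are assumed so that mean.reshape_as(x) is valid.
model = TwoLayerNetDynamic(input_size=1, hidden_size=3, output_size=1).to(DEVICE)

x_pos = torch.ones(1, 1, device=DEVICE)   # mean > 0  -> early_exit_head_1
x_neg = -torch.ones(1, 1, device=DEVICE)  # mean <= 0 -> last_exit

print(model(x_pos).shape)  # torch.Size([1, 2]) after the cat with mean
print(model(x_neg).shape)
```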

Using the commented version 1 of `forward`, I used `torch.jit.script` as mentioned here.
I managed to export an ONNX model that has the dynamic behavior I want by doing the following:

```python
model = TwoLayerNetDynamic(input_size=1, hidden_size=3, output_size=1).to(DEVICE)

# dummy input; a single-element tensor is assumed here so that
# mean.reshape_as(x) in forward is valid
_x = torch.randn(1, 1).to(DEVICE)

script_module = torch.jit.script(model)
torch.onnx.export(
    model=script_module,
    args=(_x,),
    f=f"./models/onnx/{model.model_name}_scripting.onnx",
)
```
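
To double-check that the exported graph really branches at runtime (and not just in the Netron rendering below), here is a small verification sketch; it assumes `onnxruntime` is installed and that the model has a single input and a single output:

```python
# Sketch: run the scripted ONNX export with onnxruntime and flip the sign
# of the input to trigger each branch. Assumes onnxruntime is installed.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(f"./models/onnx/{model.model_name}_scripting.onnx")
input_name = session.get_inputs()[0].name

for value in (1.0, -1.0):  # mean > 0 vs. mean <= 0
    x = np.full((1, 1), value, dtype=np.float32)
    (out,) = session.run(None, {input_name: x})
    print(value, out)
```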

I at least think this is the case because, when I open the exported model in Netron, I can see the control-flow statement captured as an `If` node:

[Netron screenshot of the exported graph showing the captured `If` node]

So, what is the problem?
When I try to do the same thing with the new dynamo-based export API, I get an error.
The procedure there is this:

```python
### Using TorchDynamo ###
onnx_filepath = f"./models/onnx/{model.model_name}_dynamo.onnx"
onnx_program: torch.onnx.ONNXProgram = torch.onnx.export(
    model=model,
    args=(_x,),
    dynamo=True,
    report=True,
)
onnx_program.save(onnx_filepath)
```

When doing this, I get a lengthy error, but also the markdown report for the export. The error is:

```
Traceback (most recent call last):
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 553, in _add_nodes
    _handle_call_function_node_with_lowering(
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 444, in _handle_call_function_node_with_lowering
    raise _errors.DispatchError(
torch.onnx._internal.exporter._errors.DispatchError: No ONNX function found for <torch._higher_order_ops.cond.CondOp object at 0x7f87bb7a7a30>. Failure message: No decompositions registered for the real-valued input

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 1134, in export
    onnx_program = _exported_program_to_onnx_program(
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 791, in _exported_program_to_onnx_program
    values = _add_nodes(exported_program, model, lower=lower, registry=registry)
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 565, in _add_nodes
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%addmm, %p_early_exit_head_1_bias, %p_early_exit_head_1_weight, %p_last_exit_bias, %p_last_exit_weight)), kwargs = {}). See the stack trace for more information.
```

From what I can see, `torch.cond` (the `torch.ops.higher_order.cond` op) is not supported by the dynamo-based exporter yet.
Searching online, I came across this PR to PyTorch.
So it looks like these higher-order operators (HOPs) are meant to be supported.
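
For what it's worth, a small sketch (reusing `model` and `_x` from above) suggests tracing itself is fine and the failure is in the ONNX lowering: `torch.export` captures the cond HOP without complaint, and the resulting graph contains exactly the `torch.ops.higher_order.cond` node named in the error.

```python
# Sketch: check that torch.export captures the higher-order cond op.
# Reuses model and _x defined earlier.
ep = torch.export.export(model, (_x,))
print(ep.graph)  # includes a call_function node targeting torch.ops.higher_order.cond
```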

My question is, am I doing something wrong?
Any insights?

Environment

```
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
```