Describe the bug
Environment
pytorch.version = 2.5.1+cu124
Description
I am trying to implement a dummy example of a model whose forward-pass operations depend on an intermediate calculation over the input. The final goal is to see whether I can export such a model to ONNX.
The model in question is the following:
```python
import torch
import torch.nn as nn

# Assumption: `fc` is the experimental control-flow namespace; the original
# import was not shown in this snippet.
from functorch.experimental import control_flow as fc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class TwoLayerNetDynamic(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(TwoLayerNetDynamic, self).__init__()
        self.model_name = 'TwoLayerNetDynamic'
        self.fully_connected_1 = nn.Linear(input_size, hidden_size)
        self.early_exit_head_1 = nn.Linear(hidden_size, output_size)
        self.threshold = torch.tensor([0.0], dtype=torch.float32).to(DEVICE)
        self.last_exit = nn.Linear(hidden_size, output_size)
        self.training_exits = False
        print("TwoLayerNetDynamic initialized")

    def forward(self, x: torch.Tensor):
        mean = x.mean()
        # version 1: plain data-dependent Python control flow
        # if torch.gt(mean, self.threshold):
        #     x = self.early_exit_head_1(x)
        # else:
        #     x = self.last_exit(x)
        # x = torch.cat([x, mean.reshape_as(x)], dim=1)

        # version 2: functional control flow via cond
        x = self.fully_connected_1(x)
        x = fc.cond(mean > 0.0, self.early_exit_head_1, self.last_exit, (x,))
        x = torch.cat([x, mean.reshape_as(x)], dim=1)
        return x
```
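For reference, the same branch can also be written with the public `torch.cond` API available in 2.5 (a sketch; as far as I understand, it dispatches to the same `torch.ops.higher_order.cond` op, so I would expect identical export behavior):

```python
# Sketch: the same forward written with the public torch.cond API (2.5).
# torch.cond(pred, true_fn, false_fn, operands) lowers to the same
# torch.ops.higher_order.cond node that appears in the error below.
def forward(self, x: torch.Tensor):
    mean = x.mean()
    x = self.fully_connected_1(x)
    x = torch.cond(mean > 0.0, self.early_exit_head_1, self.last_exit, (x,))
    x = torch.cat([x, mean.reshape_as(x)], dim=1)
    return x
```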
Using the commented version 1 of `forward`, I used `torch.jit.script` as mentioned here. I managed to export an ONNX model that has the dynamic behavior I want by doing the following:
```python
model = TwoLayerNetDynamic(input_size=1, hidden_size=3, output_size=1).to(DEVICE)
# Hypothetical example input; the actual `_x` was not shown in the snippet.
_x = torch.randn(1, 1, device=DEVICE)

script_module = torch.jit.script(model)
torch.onnx.export(
    model=script_module,
    args=_x,
    f=f"./models/onnx/{model.model_name}_scripting.onnx"
)
```
I believe this is the case because, when I open the exported model in Netron, I can see the control-flow statement captured in the graph.
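To double-check beyond Netron, I can run the exported model with onnxruntime and confirm that the two branches produce different outputs (a sketch, assuming onnxruntime is installed; the input name is read from the session):

```python
# Sketch: runtime check that the scripted export actually branches.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(f"./models/onnx/{model.model_name}_scripting.onnx")
input_name = sess.get_inputs()[0].name

pos = np.full((1, 1), 2.0, dtype=np.float32)   # mean > 0 -> early_exit_head_1
neg = np.full((1, 1), -2.0, dtype=np.float32)  # mean <= 0 -> last_exit

print(sess.run(None, {input_name: pos}))
print(sess.run(None, {input_name: neg}))
```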
So, what is the problem?
When I try to do the same thing with the new Dynamo-based API, I get an error.
The procedure there is this:
```python
### Using TorchDynamo ###
onnx_filepath = f"./models/onnx/{model.model_name}_dynamo.onnx"

onnx_program: torch.onnx.ONNXProgram = torch.onnx.export(
    model=model,
    args=(_x,),
    dynamo=True,
    report=True
)
onnx_program.save(onnx_filepath)
```
When doing this, I get a lengthy error, but also the markdown report for the export. The error is:
```
Traceback (most recent call last):
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 553, in _add_nodes
    _handle_call_function_node_with_lowering(
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 444, in _handle_call_function_node_with_lowering
    raise _errors.DispatchError(
torch.onnx._internal.exporter._errors.DispatchError: No ONNX function found for <torch._higher_order_ops.cond.CondOp object at 0x7f87bb7a7a30>. Failure message: No decompositions registered for the real-valued input

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 1134, in export
    onnx_program = _exported_program_to_onnx_program(
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 791, in _exported_program_to_onnx_program
    values = _add_nodes(exported_program, model, lower=lower, registry=registry)
  File "/home/iony/miniconda3/envs/lgvit/lib/python3.9/site-packages/torch/onnx/_internal/exporter/_core.py", line 565, in _add_nodes
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%addmm, %p_early_exit_head_1_bias, %p_early_exit_head_1_weight, %p_last_exit_bias, %p_last_exit_weight)), kwargs = {}). See the stack trace for more information.
```
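For what it's worth, the error above shows that the exported graph does contain the `torch.ops.higher_order.cond` node, so the Dynamo capture itself seems fine and the failure appears to happen in the ONNX translation step. A quick way to confirm this (a sketch):

```python
# Sketch: capture with torch.export on its own. If this succeeds and the
# printed graph contains torch.ops.higher_order.cond, the Dynamo capture is
# fine and the failure is in the ONNX dispatch/translation step.
exported_program = torch.export.export(model, (_x,))
print(exported_program.graph)
```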
And from what I can see, the `torch.cond` function is not supported yet.
Searching online, I came across this PR to PyTorch, so it looks like these higher-order operators (HOPs) should be supported.
My question is: am I doing something wrong?
Any insights?
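In the meantime, the fallback I can think of (a sketch, assuming it is acceptable to always compute both heads) is to evaluate both branches and select with `torch.where`, which avoids the higher-order op entirely and should export as a plain ONNX `Where` node:

```python
# Sketch of a cond-free fallback: compute both heads and pick one with
# torch.where. Costs an extra head evaluation per forward pass, but avoids
# torch.ops.higher_order.cond in the exported graph.
def forward(self, x: torch.Tensor):
    mean = x.mean()
    x = self.fully_connected_1(x)
    early = self.early_exit_head_1(x)
    late = self.last_exit(x)
    x = torch.where(mean > 0.0, early, late)
    x = torch.cat([x, mean.reshape_as(x)], dim=1)
    return x
```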
Environment
```
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-126-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
```