ONNX export problem when using TorchScript beforehand

Hello,

I have a vision model based on DEIM-DFINE. I can convert this model to ONNX (with tracing) and I can convert it to TorchScript, but when I try to convert the TorchScript model to ONNX, it fails.

I need one control-flow branch in the final ONNX model, which is why I can't use tracing; I need the TorchScript (scripting) step first.

However, for some reason ONNX is not able to infer shapes, even when I hardcode them. This leads to a cascade of errors in many parts of the model, and I have no idea how to resolve any of them, or whether it is even possible. The model always operates on the same input and output dimensions, so it is quite basic as far as shapes go. I could literally define the shapes by hand in most layers, but I see no way to do that.

I tried several things for the error message below, for example converting to float and hardcoding the shapes like this:

shape_tensor = torch.tensor([1, 256, 15, 15], dtype=torch.int64)
value_att = self.norm1(value_att.view(shape_tensor))
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator group_norm, unknown input rank. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues
  [Caused by the value '1182 defined in (%1182 : FloatTensor(device=cpu) = onnx::Cast[to=1], scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py:782:31
  )' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Cast'.]
  (node defined in File "Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py", line 782
      # value_att = value_att
      # shape_tensor = torch.tensor([1, 256, 15, 15], dtype=torch.int64)
      value_att = self.norm1(value_att.float())
                             ~~~~~~~~~~~~~~~ <--- HERE
  )
  Inputs:
      #0: value_att.3 defined in (%value_att.3 : FloatTensor(device=cpu) = onnx::Add(%1179, %1180), scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/

But nothing worked, and it seems these hardcoded shapes are thrown away during the TorchScript conversion, so the ONNX conversion can't access them.

Is there a way to create an ONNX model with simple control flow from a PyTorch model without TorchScript? Can I fix the TorchScript model so that the ONNX conversion works? The whole model is unfortunately quite big.

Thanks in advance!


Here is another error message that relates more directly to the problem I am describing:

torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>' instance and a default value was not provided.
  [Caused by the value '1182 defined in (%1182 : Tensor(*, *, *, *) = onnx::Reshape[allowzero=0](%value_att.3, %1181), scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py:782:31
  )' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Reshape'.]
  (node defined in File "Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py", line 782
      # value_att = value_att
      shape_tensor = torch.Size([1, 256, 15, 15])
      value_att = self.norm1(value_att.view(shape_tensor))
                             ~~~~~~~~~~~~~~ <--- HERE
  )
  Inputs:
      #0: value_att.3 defined in (%value_att.3 : Tensor = onnx::Add(%1124, %value_att), scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py:779:20) (type 'Tensor')
      #1: 1181 defined in (%1181 : Long(4, strides=[1], device=cpu) = onnx::Constant[value= 1 256 15 15 [ CPULongType{4} ]](), scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py:782:31) (type 'Tensor')
  Outputs:
      #0: 1182 defined in (%1182 : Tensor(*, *, *, *) = onnx::Reshape[allowzero=0](%value_att.3, %1181), scope: engine.deim.deim.DEIM::/engine.deim.hybrid_encodereff.HybridEncoderEff::encoder/torch.nn.modules.container.Sequential::pre_small/engine.deim.hybrid_encodereff.LowFormerEncoderLayer::0 # Detec/tools/deployment/../../engine/deim/hybrid_encodereff.py:782:31) (type 'Tensor')

Update: I can run inputs through the TorchScript model, but I can't save it or do anything else with it.

If I save it with torch.jit.save, this error comes up:

RuntimeError: strides() called on an undefined Tensor

The same error comes up when I want to print the code:

print(model_scripted.code)

TensorRT conversion also fails.

Nevertheless, a model file is written by TorchScript, but it can't be loaded; this error comes up:

RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found

The strides() error was solved by adding Optional annotations everywhere a potential None could occur, in parameters and in variables!
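
For illustration, a minimal sketch of what I mean (the module is made up; the point is the Optional annotations on both the parameter and the attribute):

from typing import Optional

import torch
from torch import nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        # Attributes/variables that may hold None need the annotation too;
        # without it, TorchScript infers a plain Tensor type, which is what
        # caused the strides() error for me.
        self.cache: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.proj(x)
        if mask is not None:
            # Inside this branch TorchScript refines Optional[Tensor] to Tensor.
            out = out * mask
        return out


scripted = torch.jit.script(Block())
torch.jit.save(scripted, "block.pt")  # saving works once the annotations are in place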

The constants.pkl error is simply TorchScript having written a model file even though the save process failed. I wonder why it does that.

The ONNX error from the top seems to point to a severe deficiency in the TorchScript-to-ONNX communication. A minimal example that produces the same error is this:

from typing import Optional

import torch
from torch import nn


class MyTestBB(nn.Module):

    def __init__(self):
        super().__init__()
        channels = [3, 60, 120, 240, 280, 320]
        convs = [nn.Conv2d(channels[i], channels[i + 1], kernel_size=3, stride=2) for i in range(5)]
        self.convs = nn.ModuleList(convs)

    def forward(self, x):
        # Collecting intermediate outputs in a list is what later breaks
        # the TorchScript-to-ONNX conversion.
        outs = []
        for ind, layer in enumerate(self.convs):
            x = layer(x)
            if ind > 2:
                outs.append(x)
        return outs[0], outs[1]


class MyTestModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.norm1 = nn.GroupNorm(1, 280)  # 384
        self.backbone = MyTestBB()

    def forward(self, images: torch.Tensor, sth: Optional[list[dict[str, torch.Tensor]]] = None):
        outputs = self.backbone(images)
        listx: torch.Tensor = outputs[0]
        listx = self.norm1(listx)
        return listx, listx

If you convert MyTestModel() first to TorchScript and then to ONNX, ONNX is unable to handle GroupNorm or LayerNorm.

torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator group_norm, unknown input rank.
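
For reference, the conversion that triggers it is roughly this (a sketch; the file name, input size, and opset are arbitrary):

import torch

model = MyTestModel().eval()

# Scripting (not tracing) preserves the control flow, and the scripted
# graph is what the ONNX exporter then chokes on at the GroupNorm.
scripted = torch.jit.script(model)

torch.onnx.export(
    scripted,
    (torch.randn(1, 3, 480, 480),),
    "mytest.onnx",
    input_names=["images"],
    output_names=["out0", "out1"],
    opset_version=17,
)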

It can only be solved by removing any kind of list. Even when the backbone returns only a single item from the list via hardcoded indexing, GroupNorm fails.

It can, for example, be solved by converting the list to a tuple and only mutating the contents of the preallocated tensors:

    def forward(self, x):
        # Preallocate the outputs as a fixed-size tuple; the loop only
        # writes into the tensors' contents, never into the container.
        outs = (torch.randn(1, 60, 240, 240), torch.randn(1, 120, 120, 120),
                torch.randn(1, 240, 60, 60), torch.randn(1, 280, 30, 30),
                torch.randn(1, 320, 15, 15))
        for ind, layer in enumerate(self.convs):
            x = layer(x)
            if ind > 2:
                outs[ind][:] = x
        return outs[3], outs[4]

Is there no alternative? Can we please fix that?

Please use the dynamo=True option and try again.

That unfortunately won't work. Dynamo only allows tracing. I need scripting to create one control-flow branch in the final ONNX model.

Would it be possible to rewrite with torch.cond?
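
Roughly like this (a hypothetical sketch; note that both branches have to return tensors with matching shapes and dtypes):

import torch

def true_fn(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

def false_fn(x: torch.Tensor) -> torch.Tensor:
    return x + 1.0

def branched(x: torch.Tensor) -> torch.Tensor:
    # torch.cond captures a data-dependent branch for torch.export/dynamo;
    # both branches must produce outputs with identical shape and dtype.
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))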

No. Tracing is tracing.

After a long and painful journey, I achieved the conversion PyTorch → TorchScript → ONNX → TensorRT.

Basically you have to eliminate any kind of list and hardcode everything with tuples, which you also annotate, like: tuple[torch.Tensor, torch.Tensor, torch.Tensor].

So if your tuple has size 2, you have to create a new model class with the new annotation tuple[torch.Tensor, torch.Tensor]. There is no other option. Don't try anything else; you will drown in weird errors that make no sense and don't point to the actual problem xD. (The main problem is the TorchScript-to-ONNX conversion.)
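
As an illustration, a made-up two-output head with the kind of annotation I mean (recent PyTorch accepts the lowercase tuple[...] form; older versions need typing.Tuple):

import torch
from torch import nn


class TwoOutputHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    # The return annotation spells out the exact tuple arity; a list
    # (or an unannotated container) breaks the TorchScript-to-ONNX step.
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        y = self.conv(x)
        return y, y.mean(dim=(2, 3))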

After finally retrieving the converted model, I saw that the CUDA optimizations completely fail on it: changing shapes depending on the input content (as opposed to the input shape, e.g. a dynamic batch size) breaks CUDA graph optimizations. I guess you can include a condition in your model, but this condition shouldn't change the shape of a tensor that is further processed by layers. The resulting model was far slower than just using ONNX (tracing) or TorchScript (scripting, but with no condition included).

I hope anybody attempting this will read this post and just not do it, at least if efficiency is a concern.

Even if efficiency is not your concern, I would recommend splitting the model into submodels and executing them depending on the output of the previous submodel, as sketched below. It is faster and easier than trying to put everything into one model.
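
Sketched out, the driver would look something like this (hypothetical file names; each part is exported with plain tracing and has static shapes):

import torch

part_a = torch.jit.load("part_a.pt")
part_b_small = torch.jit.load("part_b_small.pt")
part_b_large = torch.jit.load("part_b_large.pt")

def run(images: torch.Tensor) -> torch.Tensor:
    feats, score = part_a(images)
    # The branch lives in plain Python, outside the exported graphs,
    # so every submodel stays traceable and CUDA-graph friendly.
    if score.item() > 0.5:
        return part_b_large(feats)
    return part_b_small(feats)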

The final model I got is slower on Jetson devices and on NVIDIA GPUs. Every shape must be known before execution, derived from the input resolution; otherwise it will be slow.

The biggest efficiency hit came from TensorRT on the Jetson devices, where the model was 2x as slow as the traced model (no if). The ONNX and TorchScript variants were at least ~20% slower.