Having trouble converting a CurveNet PyTorch model to ONNX

Hi, I am trying to convert a CurveNet model, saved as a .pth file, to an ONNX file, but I can’t get it to work. Here are the steps I took:

  1. Download the CurveNet repo and upload it to my Google Drive.

  2. Use Colab with a GPU to train the model and obtain ‘model.pth’.

  3. Create a folder containing the files shown in the picture below:

‘curvenet_util.py’, ‘walk.py’ and ‘models’ are files from the CurveNet repo.

‘pytorch2onnx.py’ is the script I am going to use to convert ‘model.pth’ to an ONNX file; its contents are below:

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from curvenet_util import *

# Define a convolutional neural network
curve_config = {
        'default': [[100, 5], [100, 5], None, None],
        'long':  [[10, 30], None,  None,  None]
    }

class CurveNet(nn.Module):
    def __init__(self, num_classes=40, k=20, setting='default'):
        super(CurveNet, self).__init__()

        assert setting in curve_config

        additional_channel = 32
        self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True)

        # encoder
        self.cic11 = CIC(npoint=1024, radius=0.05, k=k, in_channels=additional_channel, output_channels=64, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][0])
        self.cic12 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=64, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][0])
        
        self.cic21 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=128, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][1])
        self.cic22 = CIC(npoint=1024, radius=0.1, k=k, in_channels=128, output_channels=128, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][1])

        self.cic31 = CIC(npoint=256, radius=0.1, k=k, in_channels=128, output_channels=256, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][2])
        self.cic32 = CIC(npoint=256, radius=0.2, k=k, in_channels=256, output_channels=256, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][2])

        self.cic41 = CIC(npoint=64, radius=0.2, k=k, in_channels=256, output_channels=512, bottleneck_ratio=2, mlp_num=1, curve_config=curve_config[setting][3])
        self.cic42 = CIC(npoint=64, radius=0.4, k=k, in_channels=512, output_channels=512, bottleneck_ratio=4, mlp_num=1, curve_config=curve_config[setting][3])

        self.conv0 = nn.Sequential(
            nn.Conv1d(512, 1024, kernel_size=1, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True))
        self.conv1 = nn.Linear(1024 * 2, 512, bias=False)
        self.conv2 = nn.Linear(512, num_classes)
        self.bn1 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)

    def forward(self, xyz):
        l0_points = self.lpfa(xyz, xyz)

        l1_xyz, l1_points = self.cic11(xyz, l0_points)
        l1_xyz, l1_points = self.cic12(l1_xyz, l1_points)

        l2_xyz, l2_points = self.cic21(l1_xyz, l1_points)
        l2_xyz, l2_points = self.cic22(l2_xyz, l2_points)

        l3_xyz, l3_points = self.cic31(l2_xyz, l2_points)
        l3_xyz, l3_points = self.cic32(l3_xyz, l3_points)
 
        l4_xyz, l4_points = self.cic41(l3_xyz, l3_points)
        l4_xyz, l4_points = self.cic42(l4_xyz, l4_points)

        x = self.conv0(l4_points)
        x_max = F.adaptive_max_pool1d(x, 1)
        x_avg = F.adaptive_avg_pool1d(x, 1)
        
        x = torch.cat((x_max, x_avg), dim=1).squeeze(-1)
        x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1)
        x = self.dp1(x)
        x = self.conv2(x)
        return x


# Create the CurveNet model by using the above model definition.
torch_model = CurveNet().cpu()
#print(torch_model)

model_path = 'C:/Users/chris/CompetitionMMwave/pytorchTraining/model.pth'


# Initialize model with the pretrained weights
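# Map tensors that were saved on a GPU to CPU storage when no CUDA device is available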
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(torch.load(model_path, map_location=map_location))

# set the model to inference mode
torch_model.eval()
#print(torch_model)
#print("----------")

# Input to the model: a batch of 32 point clouds, 3 coordinates x 1024 points

x = torch.randn(32, 3, 1024, requires_grad=True, device='cpu')

print('x:',x)

#torch_out = torch_model(x)
#print(torch_out)

import torch.onnx 


# Function to convert the model to ONNX
def Convert_ONNX(): 

    # Export the model
    torch.onnx.export(torch_model,               # model being run
                    x,                         # model input (or a tuple for multiple inputs)
                    "model.onnx",   # where to save the model (can be a file or file-like object)
                    export_params=True,        # store the trained parameter weights inside the model file
                    opset_version=11,          # the ONNX version to export the model to
                    do_constant_folding=True,  # whether to execute constant folding for optimization
                    input_names = ['input'],   # the model's input names
                    output_names = ['output'], # the model's output names
                    dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})
    print(" ") 
    print('Model has been converted to ONNX')

Convert_ONNX()

Then I got the errors below:

C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\cuda\__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  ..\c10\cuda\CUDAFunctions.cpp:100.) 
  return torch._C._cuda_getDeviceCount() > 0
x: tensor([[[-1.0610, -0.5267,  1.4753,  ..., -0.4559, -0.1569, -1.5430],
         [ 0.0778, -0.4320, -0.5244,  ...,  0.4428, -0.1741,  0.3060],
         [ 0.1440, -1.5503,  1.1274,  ..., -1.0515,  0.0273, -1.2841]],

        [[-0.7762,  0.2611, -1.1352,  ..., -1.3489,  1.0865,  0.4396],
         [-0.4481,  2.0752,  0.4898,  ...,  0.0846, -0.0680,  0.0815],
         [ 0.6015,  2.7668,  0.3792,  ...,  1.6718, -0.2551,  2.4854]],

        [[-1.4961,  0.6301,  0.2529,  ..., -1.5639,  0.3833, -0.6893],
         [-0.4424, -0.7423,  0.6153,  ..., -1.9716,  0.1808, -0.8182],
         [ 0.5502, -0.1931,  0.7892,  ..., -0.0410, -1.2228,  0.1584]],

        ...,

        [[-0.9705, -1.4239, -0.4263,  ...,  0.1071,  0.0304, -1.5994],
         [-1.3054, -0.8234, -0.7768,  ...,  0.6917, -0.3518, -0.1506],
         [-1.1786, -0.4557, -0.1489,  ..., -1.3476,  2.7490, -0.3241]],

        [[ 0.5510, -1.8050,  1.0268,  ..., -1.0423, -0.6780, -1.5962],
         [ 0.3624,  0.1122,  0.1071,  ...,  1.5958,  0.6209, -1.3937],
         [ 1.4254,  0.1170,  0.1670,  ..., -1.2565, -0.9526, -2.0219]],

        [[ 2.6874,  1.1777, -1.6811,  ..., -0.2441, -0.0778,  1.7057],
         [ 0.0602,  0.6392, -0.3591,  ...,  0.7949,  0.6193,  0.9046],
         [ 0.3005,  1.2581,  0.4919,  ...,  0.4723, -1.3764,  0.5513]]],
       requires_grad=True)
torch.Size([32, 3, 1024])
torch.Size([32, 16, 1024])
torch.Size([32, 16, 1024])
torch.Size([32, 32, 1024])
torch.Size([32, 32, 1024])
torch.Size([32, 64, 256])
torch.Size([32, 64, 256])
torch.Size([32, 128, 64])
torch.Size([32, 128, 64])
tensor([[ 8.0885, -4.3205, -0.3741,  ..., -2.1033, -0.7325, -2.5036],
        [10.6468, -3.3417,  2.4449,  ..., -4.2334,  0.5905, -3.7904],
        [14.5540, -5.7213,  3.9037,  ..., -2.1741,  2.3348, -2.4426],
        ...,
        [ 9.1924, -3.6008,  3.4110,  ..., -2.6812,  0.3097, -3.6180],
        [11.4883, -5.2502,  0.2312,  ..., -2.5288, -0.6245, -2.4642],
        [ 9.1515, -4.6202,  1.6596,  ..., -1.7074, -0.9047, -3.9957]],
       grad_fn=<AddmmBackward>)
c:\Users\chris\CompetitionMMwave\pytorchTraining\curvenet_util.py:213: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  print(x.size())
torch.Size([32, 3, 1024])
c:\Users\chris\CompetitionMMwave\pytorchTraining\curvenet_util.py:356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if xyz.size(-1) != self.npoint:
torch.Size([32, 16, 1024])
torch.Size([32, 16, 1024])
torch.Size([32, 32, 1024])
torch.Size([32, 32, 1024])
c:\Users\chris\CompetitionMMwave\pytorchTraining\curvenet_util.py:90: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) * 0
torch.Size([32, 64, 256])
torch.Size([32, 64, 256])
torch.Size([32, 128, 64])
torch.Size([32, 128, 64])
C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_opset9.py:2374: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn("Exporting aten::index operator of advanced indexing in opset " +
C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_opset9.py:2332: UserWarning: Exporting aten::index operator with indices of type Byte. Only 1-D indices are supported. In any other case, this will produce an incorrect ONNX graph.
  warnings.warn("Exporting aten::index operator with indices of type Byte. "
C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_opset9.py:584: UserWarning: This model contains a squeeze operation on dimension 1 on an input with unknown shape. Note that if the size of dimension 1 of the input is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on non-singleton dimensions, it is recommended to export this model using opset version 11 or higher.
  warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " +
Traceback (most recent call last):
  File "C:\Users\chris\anaconda3\envs\python38\lib\runpy.py", line 192, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\chris\anaconda3\envs\python38\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\chris\.vscode\extensions\ms-python.python-2021.9.1191016588\pythonFiles\lib\python\debugpy\__main__.py", line 45, in <module>
    cli.main()
  File "c:\Users\chris\.vscode\extensions\ms-python.python-2021.9.1191016588\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 444, in main
    run()
  File "c:\Users\chris\.vscode\extensions\ms-python.python-2021.9.1191016588\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "C:\Users\chris\anaconda3\envs\python38\lib\runpy.py", line 262, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "C:\Users\chris\anaconda3\envs\python38\lib\runpy.py", line 95, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "C:\Users\chris\anaconda3\envs\python38\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\chris\CompetitionMMwave\pytorchTraining\pytorchTrainingTest.py", line 101, in <module>
    torch.onnx.export(torch_model,               # model being run
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\__init__.py", line 225, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\utils.py", line 85, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
    graph = _optimize_graph(graph, operator_export_type,
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\utils.py", line 203, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\__init__.py", line 263, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\utils.py", line 934, in _run_symbolic_function
    return symbolic_fn(g, *inputs, **attrs)
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_helper.py", line 128, in wrapper
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_helper.py", line 128, in <listcomp>
    args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
  File "C:\Users\chris\anaconda3\envs\python38\lib\site-packages\torch\onnx\symbolic_helper.py", line 80, in _parse_arg
    raise RuntimeError("Failed to export an ONNX attribute '" + v.node().kind() +
RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible

I have been stuck on this for a few days. Does anyone know how to convert this model to ONNX, or have any suggestions :weary: :weary:?

It’s my first time asking for help here; if there is anything I should have mentioned but didn’t, please let me know.

Thanks!

The warnings and errors seem to be created by tracing the model, which doesn’t allow for dynamic control flow, so you might want to check the Tracing vs. Scripting docs and mix both modes.
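
For illustration, here is a minimal sketch of that mixing. ‘DynamicBlock’ and its shape-dependent branch are hypothetical stand-ins for the CurveNet layers that trigger the TracerWarnings (e.g. the ‘if xyz.size(-1) != self.npoint:’ check in curvenet_util.py); scripting only those submodules keeps their branches as real graph control flow while the rest of the model is still traced during export:

import torch
import torch.nn as nn

class DynamicBlock(nn.Module):
    # Hypothetical stand-in for a layer with data-dependent control flow;
    # tracing would bake in whichever branch the example input takes.
    def forward(self, xyz):
        if xyz.size(-1) != 1024:
            xyz = xyz[:, :, :1024]  # branch that a trace would freeze
        return xyz

class Mixed(nn.Module):
    def __init__(self):
        super(Mixed, self).__init__()
        # Scripting just this submodule records the if-branch as graph
        # control flow; the surrounding model can still be traced.
        self.block = torch.jit.script(DynamicBlock())
        self.conv = nn.Conv1d(3, 8, kernel_size=1)

    def forward(self, xyz):
        return self.conv(self.block(xyz))

model = Mixed().eval()
dummy = torch.randn(1, 3, 2048)
torch.onnx.export(model, dummy, "mixed.onnx", opset_version=11,
                  input_names=['input'], output_names=['output'])

Scripting is not a cure-all here, though: the final RuntimeError (the ‘onnx::Gather’ attribute is not constant) usually means a value that must be a compile-time constant, such as a kernel or neighborhood size, is being computed from a tensor shape at runtime, so you may also need to make those sizes static, as the error message suggests, before the export goes through.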