Hi,
I would like to export a network that I created in PyTorch to ONNX and then import it into OpenCV. Here is a minimal working example:
import torch
import torch.nn as nn
import cv2 as cv
import torch.nn.functional as F
class Network(nn.Module):
    """Small 1-D network with a pool/upsample branch and a skip connection.

    The input is convolved down to one channel. That result is max-pooled by
    a factor of 2 and brought back to the original length with a stride-2
    transposed convolution. Both tensors are concatenated along the channel
    axis and a final convolution reduces them to a single output channel.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed
            convolution branch. The network output itself always has one
            channel.

    Note:
        The concatenation needs both branches to have the same length, so the
        input length must be even: pooling floors an odd length and the
        transposed convolution only doubles the pooled length.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, 1, kernel_size=3, padding='same')
        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        # conv2 consumes the concatenation of conv1's single channel and the
        # out_channels channels coming from upconv, hence 1 + out_channels
        # (this equals the previous hard-coded 2 when out_channels == 1).
        self.conv2 = nn.Conv1d(1 + out_channels, 1, kernel_size=3, padding='same')
        nn.init.normal_(self.conv2.weight, mean=0.0, std=1.0)
        self.upconv = nn.ConvTranspose1d(1, out_channels, kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the single-channel result."""
        convolved = self.conv1(x)
        pooled = F.max_pool1d(input=convolved, kernel_size=2)
        upconvolved = self.upconv(pooled)
        # Skip connection: stack the full-resolution features with the
        # pooled-then-upsampled ones along the channel axis.
        concatenated = torch.cat((convolved, upconvolved), dim=1)
        final = self.conv2(concatenated)
        return final
# Build the single-channel model and push one sample through it in PyTorch.
network = Network(1, 1)
dummy_input = torch.tensor([[[1.0, 2.0, 3.0, 5.0]]], dtype=torch.float32)  # shape (1, 1, 4)
output = network(dummy_input)

# Export the model to ONNX; only axis 0 (the batch axis) is declared dynamic.
torch.onnx.export(
    network,
    dummy_input,
    "dummy_model.onnx",
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={name: {0: 'batch_size'} for name in ('input', 'output')},
)

# Load the exported file with OpenCV's DNN module.
# NOTE: the reported cv2.error is raised by this call.
net = cv.dnn.readNetFromONNX("dummy_model.onnx")

# Hand the sample to OpenCV as a 4-D array and run inference.
input_array = dummy_input.numpy().reshape(1, 1, 1, -1)
net.setInput(input_array)
output = net.forward()
However, an exception is thrown at the line net = cv.dnn.readNetFromONNX("dummy_model.onnx")
which says
cv2.error: OpenCV(4.9.0) D:\a\opencv-python\opencv-python\opencv\modules\dnn\src\onnx\onnx_importer.cpp:1053: error: (-2:Unspecified error) in function 'cv::dnn::dnn4_v20231225::ONNXImporter::handleNode'
> Node [Concat@ai.onnx]:(onnx_node!/Concat) parse error: OpenCV(4.9.0) D:\a\opencv-python\opencv-python\opencv\modules\dnn\src\layers\concat_layer.cpp:105: error: (-215:Assertion failed) curShape.size() == outputs[0].size() in function 'cv::dnn::ConcatLayerImpl::getMemoryShapes'
What is going on here, and how can I resolve the issue?
I am using torch version 2.3.0 and opencv version 4.9.0