Export object detection model to ONNX: empty output from ONNX inference

I am trying to convert my PyTorch object detection model (Faster R-CNN) to ONNX. I have two setups. The first one works correctly, but I want to use the second one for deployment reasons. The only difference is the example input that I pass to torch.onnx.export().

In the first setup I use a real image as input for the ONNX export. However, an official tutorial says that a dummy input can be used, as long as it has the size the model expects. So I created a tensor with the same shape but with random values. The export succeeds in both setups, but the second setup does not produce the expected results when I run inference with ONNX Runtime. The code and example output are below.

Setup 1

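import torch
import torchvision
from PIL import Image
from torchvision import transforms

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
...
# Load the fine-tuned weights and switch to inference mode
checkpoint = torch.load(model_state_dict_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Read a real image and convert it to a float tensor of shape (1, 3, H, W)
to_tensor = transforms.ToTensor()
img_rgb = Image.open(image_path_model).convert('RGB')
img_rgb = to_tensor(img_rgb)
img_rgb.unsqueeze_(0)

# Export using the real image as the example input
torch.onnx.export(model, img_rgb, "detection.onnx", opset_version=11)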

I get no error and the export works. Afterwards I run the model with the ONNX runtime and I get the following output:

[array([[704.0696  , 535.19556 , 944.8986  , 786.1619  ],
         ...], dtype=float32),
array([2, 2, 2, 2, 2, 1, 1], dtype=int64),
array([0.9994363 , 0.9984769 , 0.99816966, ...], dtype=float32)]

The output is what I expect (bounding boxes, object classes and probabilities).

Setup 2

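import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
...
# Load the fine-tuned weights and switch to inference mode
checkpoint = torch.load(model_state_dict_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Dummy input: random values with the shape (1, 3, H, W)
img_rgb = torch.randn(1, 3, 1024, 1024)

# Export using the random tensor as the example input
torch.onnx.export(model, img_rgb, "detection.onnx", opset_version=11)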

As in setup 1, I get no error and the export works. Afterwards I run the model with ONNX Runtime on the same image as in setup 1 and get the following output:

[array([], shape=(0, 4), dtype=float32),
array([], dtype=int64),
array([], dtype=float32)]

All three outputs are empty arrays.

What is wrong with the second setup? I am new to ONNX. The export runs the model once. Does that mean I have to provide an input on which the model actually detects objects, and is that why the dummy input with random values does not work? Is the statement “The values in this can be random as long as it is the right type and size.” only valid for the model in that tutorial?

Related: https://github.com/pytorch/vision/issues/1706#issuecomment-658198204
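
As general background (not a confirmed diagnosis of this particular model): for a module that is not already scripted, torch.onnx.export runs the model once on the example input and records the operations it sees. Python-level decisions that depend on tensor values are then fixed to whatever the example input produced. The sketch below shows that effect with torch.jit.trace on a made-up function; the function name shift and the inputs are purely illustrative, and whether Faster R-CNN hits the same kind of issue in the setup above is an assumption.

```python
import warnings

import torch


def shift(x):
    # Data-dependent branch: which line runs depends on the values in x.
    if x.sum() > 0:
        return x + 10
    return x - 10


positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns about converting a tensor to a Python bool;
    # silenced here only to keep the output short.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(shift, positive)

eager_out = shift(negative)    # eager mode takes the "x - 10" branch
traced_out = traced(negative)  # the trace replays the recorded "x + 10" branch

print("eager :", eager_out)
print("traced:", traced_out)
```

The traced function gives a different result from the eager one for the second input, because only the branch taken during tracing was recorded. An exported graph can depend on the example input in the same way, even when the export itself reports no error.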

I had a related problem with Faster R-CNN and onnxruntime. When I export the model using a dummy input that doesn't contain an object, the model only works at inference time with similar kinds of images. That is, if there is no detection it is fine, but if the model detects an object, it crashes with the message "The input tensor cannot be reshaped to the requested shape". The same thing happens the other way around: if I export the model using a real image, then onnxruntime crashes when there is no detection at inference time (empty tensor).

So, at some point during the export, the reshape function is hardcoded for a certain tensor shape and can't handle both the detection and the no-detection case.

My solution was to concatenate the predictions with zeros, so that even when the predictions are an empty tensor, it still contains a row of zeros and the reshape can take place safely. However, I had to use a dummy input (which leads to no detection) during export for this scheme to work. I added the following code at this line: vision/roi_heads.py at 3e27eb2104c01a7a55f5b4e82d19d9b6612715b5 · pytorch/vision · GitHub, right before pred_boxes_list = ...

# Added to fix "reshape mismatch error" when there is no detection
pred_boxes = torch.cat((pred_boxes.view(-1, num_classes, 4), torch.zeros((1, num_classes, 4), device=device)), dim=0)
pred_scores = torch.cat((pred_scores, torch.zeros((1, num_classes), device=device)), dim=0)
boxes_per_image[0] += 1
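
To see what the padding does to the shapes, here is a small standalone sketch that mimics the no-detection case outside of torchvision. The value of num_classes and the empty starting tensors are assumptions chosen for illustration, and the device argument is left out so the snippet runs on CPU. It is not the actual roi_heads.py code.

```python
import torch

num_classes = 3  # hypothetical value, for illustration only

# Simulate "no proposals": empty predictions for a single image.
pred_boxes = torch.zeros((0, num_classes, 4))
pred_scores = torch.zeros((0, num_classes))
boxes_per_image = [0]

# Same padding idea as above: append one all-zero row so the tensors are never empty.
pred_boxes = torch.cat(
    (pred_boxes.view(-1, num_classes, 4), torch.zeros((1, num_classes, 4))), dim=0
)
pred_scores = torch.cat((pred_scores, torch.zeros((1, num_classes))), dim=0)
boxes_per_image[0] += 1

# The per-image split now always receives at least one row per image.
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
pred_scores_list = pred_scores.split(boxes_per_image, 0)

print(pred_boxes_list[0].shape)   # one padded row of boxes
print(pred_scores_list[0].shape)  # one padded row of scores
```

The padded row has a score of zero for every class, so it should be dropped by any positive score threshold applied later; that depends on the threshold configured for the model.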