ONNXRuntimeError: failed:Node (Gather_346) Op (Gather) [ShapeInferenceError] axis must be in [-r, r-1]

Hello all,

Recently, I have been trying to deploy my model in a real environment. The model is a Faster R-CNN based object recognition model, following the Bottom-Up Attention model proposed by Anderson et al. The model is implemented with Detectron.

My first attempt was a web service (Flask plus Redis Queue), which works but suffers delays from connection and transmission overhead. I therefore wanted a more efficient solution.

My second attempt is to export the PyTorch model and use it in a Java application through DJL or DL4J. For that, I first need to export the model as an ONNX graph. Here is what I have done so far.

import onnxruntime as ort
import numpy as np
import torch
import onnx
import cv2
from detectron2.data import transforms as T  # provides apply_transform_gens

# Initialize model with checkpoint
model = BoAModel()
model.eval()

# Load image
cv_image = cv2.imread(image_path)
# Transformation; transform_gen is used to resize the image
image, transforms = T.apply_transform_gens(transform_gen, cv_image)

# Model input: CHW float32 image tensor and the resize scale factor
data = [torch.as_tensor(image.transpose(2, 0, 1).astype("float32"), device="cpu"), torch.as_tensor([float(image_shape[0]) / float(h)], device="cpu")]

# Export ONNX
model_onnx_path = "torch_model_boa_jit_torch1.5.1.onnx"

batch_size = 1
torch.onnx.export(model=model,  # model to be exported
                  args=(data,),  # model input; the list is passed to forward() as a single argument
                  f=model_onnx_path,  # where to save the model
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=11,  # the ONNX opset version to export the model to
                  do_constant_folding=False,  # whether to execute constant folding for optimization
                  input_names=['input'],  # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}},
                  verbose=True)

The export produced a couple of warnings, which I have already looked up; so far they do not seem critical in my case. After exporting, I ran a sanity check as follows:

# Load the ONNX model
model = onnx.load(model_onnx_path)

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

So far so good. I then tried to load it in ONNX Runtime with the following code:

sess_options = ort.SessionOptions()
# Below is for optimizing performance
# sess_options.intra_op_num_threads = 24
# ...
ort_session = ort.InferenceSession(model_onnx_path, sess_options=sess_options)

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Compute ONNX Runtime output prediction (use the same data the model works well with).
# ONNX Runtime expects a dict mapping each graph input name to a NumPy array,
# so feed the tensors used for the export, one per graph input, in order.
ort_inputs = {inp.name: to_numpy(t) for inp, t in zip(ort_session.get_inputs(), data)}
ort_outs = ort_session.run(None, ort_inputs)
print("ort_outs", ort_outs)

Then I got an ONNXRuntimeError at the line ort.InferenceSession(model_onnx_path, sess_options=sess_options):

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from torch_model_boa_jit_torch1.5.1.onnx failed:Node (Gather_346) Op (Gather) [ShapeInferenceError] axis must be in [-r, r-1]

As the error shows, the problem is at node Gather_346, which is shown in the following figure:

I’m not quite sure whether the error originates at this step (e.g., the tensor data is empty in the figure). I checked the verbose export log: this Gather operation comes from box_regression, at the line widths = boxes[:, 2] - boxes[:, 0] + OFFSET, where the box predictions are computed from the proposed regions and anchors, as shown below:

  # torch.onnx.export log
  %964 : Float(22800) = onnx::Gather[axis=1](%962, %963) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
  %965 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %966 : Tensor = onnx::Constant[value={0}]()
  %967 : Float(22800) = onnx::Gather[axis=1](%965, %966) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
  %968 : Float(22800) = onnx::Sub(%964, %967) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
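The error message encodes a simple rule: Gather's axis must lie in [-r, r-1], where r is the rank of its data input. Reading the log above, the data input of the failing Gather is %965 : Float(0), a rank-1 constant with zero elements, while boxes[:, 2] gathers along axis 1, which needs rank 2 or more. The Float(22800) annotation on the output presumably comes from the traced run and is not what shape inference sees in the saved graph. A pure-Python sketch of the check (my reading of the log, not ONNX Runtime code; the helper names are hypothetical):

```python
def gather_axis_is_valid(axis: int, rank: int) -> bool:
    """ONNX Gather accepts axis only in [-r, r-1], where r is the rank of `data`."""
    return -rank <= axis <= rank - 1


def check_gather(axis: int, data_shape: tuple) -> str:
    """Describe whether Gather(axis) is legal for a data tensor of the given shape."""
    rank = len(data_shape)
    verdict = "ok" if gather_axis_is_valid(axis, rank) else "invalid"
    return f"Gather(axis={axis}) on data of shape {data_shape} (r={rank}): {verdict}"


# What boxes[:, 2] expects: boxes of shape (N, 4), so r = 2 and axis=1 is legal.
print(check_gather(1, (22800, 4)))
# What the exported graph feeds in (%965 : Float(0)): r = 1, so axis=1 is rejected.
print(check_gather(1, (0,)))
```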

The code that probably triggers the error (Gather_346) is as follows:

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4)
            boxes (Tensor): boxes to transform, of shape (N, 4)
        """
        assert torch.isfinite(deltas).all().item(), "Box regression deltas become infinite or NaN!"
        boxes = boxes.to(deltas.dtype)
        print("apply_deltas boxes.shape", boxes.shape)
        OFFSET = 1  
        widths = boxes[:, 2] - boxes[:, 0] + OFFSET 
        heights = boxes[:, 3] - boxes[:, 1] + OFFSET
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        wx, wy, ww, wh = self.weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # clamping too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
        print("pred_boxes.shape", pred_boxes.shape)
        return pred_boxes
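Once the model loads, an independent reference for this step makes it easier to compare ONNX Runtime outputs against. Below is a minimal NumPy port of apply_deltas (my own sketch, not part of Detectron), with an explicit shape check that fails loudly on a rank-1 or 0-dim boxes instead of letting it through. The weights must be passed in because the value of self.weights is not shown here; the default scale_clamp = log(1000/16) ≈ 4.13517 matches the clamp constant in the export log below.

```python
import math

import numpy as np


def apply_deltas_np(deltas, boxes, weights, scale_clamp=math.log(1000.0 / 16), offset=1.0):
    """NumPy port of apply_deltas above, for cross-checking exported outputs.

    deltas: (N, k*4); boxes: (N, 4); weights: (wx, wy, ww, wh) as in self.weights.
    """
    deltas = np.asarray(deltas, dtype=np.float32)
    boxes = np.asarray(boxes, dtype=np.float32)
    # Reject the rank problem up front instead of producing a broken graph.
    if boxes.ndim != 2 or boxes.shape[1] != 4:
        raise ValueError(f"boxes must have shape (N, 4), got {boxes.shape}")
    if deltas.ndim != 2 or deltas.shape[0] != boxes.shape[0] or deltas.shape[1] % 4:
        raise ValueError(f"deltas must have shape (N, k*4), got {deltas.shape}")

    widths = boxes[:, 2] - boxes[:, 0] + offset
    heights = boxes[:, 3] - boxes[:, 1] + offset
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = weights
    dx = deltas[:, 0::4] / wx
    dy = deltas[:, 1::4] / wy
    # Clamp before exp(), mirroring torch.clamp(dw, max=self.scale_clamp)
    dw = np.minimum(deltas[:, 2::4] / ww, scale_clamp)
    dh = np.minimum(deltas[:, 3::4] / wh, scale_clamp)

    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = np.exp(dw) * widths[:, None]
    pred_h = np.exp(dh) * heights[:, None]

    pred_boxes = np.zeros_like(deltas)
    pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
    pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
    pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
    pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2
    return pred_boxes
```

For example, apply_deltas_np(np.zeros((1, 4)), [[0, 0, 9, 9]], weights=(1.0, 1.0, 1.0, 1.0)) returns [[0, 0, 10, 10]]: with OFFSET = 1 the width is 10, so zero deltas keep the centre at 5 and the box spans 0 to 10. An input of shape (0, 4) simply yields an empty (0, 4) result, while a rank-1 boxes raises immediately.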

Here is the configuration of the Python environment:

# Python working environment (torch 1.5.1)
   virtualenv          15.0.1
   python              3.6
   torch               1.5.1
   torchvision         0.6.1
   onnx                1.10.2
   onnxruntime         1.9.0
   ubuntu              16.04
   CUDA                10.2 (440.33.01, not used in tests so far)

Since I was not sure whether this is a PyTorch version issue, I also tried the following environment:

# Python working environment (torch 1.10.0)
   virtualenv          15.0.1
   python              3.6
   torch               1.10.0
   torchvision         0.11.1
   onnx                1.10.2
   onnxruntime         1.9.0
   ubuntu              16.04
   CUDA                10.2 (440.33.01, not used in tests so far)

In this environment, the export simply crashed with a segmentation fault; here is the gdb log:

/media/WorkSpace/Development_Repository/Working/models/detectron2/modeling/poolers.py:73: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  (len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007fffc9e57c08 in std::_Function_handler<void (onnx_torch::InferenceContext&), onnx_torch::OpSchema onnx_torch::GetOpSchema<onnx_torch::ConstantOfShape_Onnx_ver9>()::{lambda(onnx_torch::InferenceContext&)#1}>::_M_invoke(std::_Any_data const&, onnx_torch::InferenceContext&) ()
   from /media/WorkSpace/Development_Repository/Working/models/venv3_torch/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so

So I stayed with the 1.5.1 environment. I also tested the exported ONNX model in Java with DL4J and got the same error as in Python:

2021-11-25 15:41:30.1671963 [I:onnxruntime:, inference_session.cc:230 onnxruntime::InferenceSession::ConstructorCommon::<lambda_dcdcfd37ad4a704b0bd7e98f885edfd8>::operator ()] Flush-to-zero and denormal-as-zero are off
2021-11-25 15:41:30.1672832 [I:onnxruntime:, inference_session.cc:237 onnxruntime::InferenceSession::ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
WARNING: Since openmp is enabled in this build, this API cannot be used to configure intra op num threads. Please use the openmp environment variables to control the number of threads.
Exception caught java.lang.RuntimeException: Load model from C:\Users\ps\IdeaProjects\Demo_DL4J\src\main\resources\checkpoint_boa.onnx failed:Node (Gather_341) Op (Gather) [ShapeInferenceError] axis must be in [-r, r-1]

Question:
Could somebody help me figure out what is going on here? In this test, both the model and the data are on the CPU. To narrow down the error, I followed the steps in "Exporting Fasterrcnn Resnet50 fpn to ONNX", which work in my 1.5.1 test environment. However, I have not managed to make it work with my own model. To make sure inference itself works correctly, I used the same data that the model normally handles well.

Any input will be appreciated!

What really confuses me is that the model works well in normal inference. With the same data, something unexpected happens during the export.

In addition, I tried torch.jit.trace to get a model saved by JIT, since torch.onnx.export relies on the same tracing mechanism. After calling

traced_model = torch.jit.trace(model, example_inputs=(data[0], data[1]))

I got the following error:

forward() takes 2 positional arguments but 3 were given
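This TypeError comes from how torch.jit.trace hands example_inputs to forward(): a tuple is unpacked into positional arguments (a single tensor is wrapped in a tuple first). So (data[0], data[1]) becomes forward(data[0], data[1]), one argument more than a forward(self, inputs) signature accepts. The pure-Python stand-in below models only that calling convention; trace_call and OneArgModel are hypothetical names, not torch APIs. An alternative to changing the signature would be to wrap the pair once more, e.g. torch.jit.trace(model, ((data[0], data[1]),)), which, as far as I know, tracing accepts because nested tuples of tensors are supported inputs.

```python
def trace_call(fn, example_inputs):
    """Simplified model of how torch.jit.trace passes example_inputs to forward()."""
    # A tuple is unpacked into positional arguments; anything else is wrapped first.
    if not isinstance(example_inputs, tuple):
        example_inputs = (example_inputs,)
    return fn(*example_inputs)


class OneArgModel:
    """Hypothetical stand-in for a forward() that takes a single pair of inputs."""

    def forward(self, inputs):
        image, scale = inputs
        return image, scale


m = OneArgModel()

# (data[0], data[1]) is unpacked into two arguments: TypeError, as in the post.
try:
    trace_call(m.forward, ("image", "scale"))
except TypeError as exc:
    print("TypeError:", exc)

# Wrapping the pair in an outer 1-tuple passes it as a single argument.
print(trace_call(m.forward, (("image", "scale"),)))
```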

To solve this, I modified the signature of my model's forward function, and that error is gone now. Back to jit.trace:
I also changed the slicing in widths = boxes[:, 2] - boxes[:, 0] + OFFSET to

widths = torch.select(boxes, dim=1, index=2) - torch.select(boxes, dim=1, index=0) + OFFSET
to make sure this step is correct. However, I ran into essentially the same problem:

RuntimeError: select() cannot be applied to a 0-dim tensor.

This means the input boxes is empty, yet the trace log says otherwise: as the following log of the apply_deltas function shows, widths has a size of 22800.

  %958 : Float(22800) = onnx::Gather[axis=1](%956, %957) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
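It may help to separate "empty" from "wrong rank". A boxes tensor with zero rows but shape (0, 4) can still be sliced with [:, 2]; what fails (here in NumPy, and analogously in select() and in ONNX Gather) is a tensor that has lost its second dimension. NumPy is used only as an analogy for the PyTorch indexing, and column_shape is a hypothetical helper. If this reading is right, the question becomes how boxes ended up as a rank-1 constant during tracing, rather than whether it holds data.

```python
import numpy as np


def column_shape(arr, col=2):
    """Return the shape of arr[:, col], or None if arr has fewer than 2 dimensions."""
    try:
        return arr[:, col].shape
    except IndexError:
        return None


no_boxes = np.zeros((0, 4), dtype=np.float32)  # zero boxes, but still shaped (N, 4)
flat_empty = np.zeros((0,), dtype=np.float32)  # rank 1, like the Float(0) constant in the log
scalar = np.zeros((), dtype=np.float32)        # rank 0, what the select() error complains about

for name, arr in [("no_boxes", no_boxes), ("flat_empty", flat_empty), ("scalar", scalar)]:
    print(f"{name}: ndim={arr.ndim}, arr[:, 2] shape -> {column_shape(arr)}")
```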

The exported model has empty data for this operation, as shown in Selection_545.

How could this happen at this step?
The rest of the log is as follows:

  %959 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %960 : Tensor = onnx::Constant[value={0}]()
  %961 : Float(22800) = onnx::Gather[axis=1](%959, %960) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
  %962 : Float(22800) = onnx::Sub(%958, %961) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:169:0
  %963 : Float() = onnx::Constant[value={1}]()
  %964 : Float(22800) = onnx::Add(%962, %963) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:186:0
  %965 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %966 : Tensor = onnx::Constant[value={3}]()
  %967 : Float(22800) = onnx::Gather[axis=1](%965, %966) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:170:0
  %968 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %969 : Tensor = onnx::Constant[value={1}]()
  %970 : Float(22800) = onnx::Gather[axis=1](%968, %969) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:170:0
  %971 : Float(22800) = onnx::Sub(%967, %970) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:170:0
  %972 : Float() = onnx::Constant[value={1}]()
  %973 : Float(22800) = onnx::Add(%971, %972) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:187:0
  %974 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %975 : Tensor = onnx::Constant[value={0}]()
  %976 : Float(22800) = onnx::Gather[axis=1](%974, %975) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:171:0
  %977 : Float() = onnx::Constant[value={0.5}]()
  %978 : Float(22800) = onnx::Mul(%964, %977)
  %979 : Float(22800) = onnx::Add(%976, %978) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:184:0
  %980 : Float(0) = onnx::Constant[value=[ CPUFloatType{0} ]]()
  %981 : Tensor = onnx::Constant[value={1}]()
  %982 : Float(22800) = onnx::Gather[axis=1](%980, %981) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:172:0
  %983 : Float() = onnx::Constant[value={0.5}]()
  %984 : Float(22800) = onnx::Mul(%973, %983)
  %985 : Float(22800) = onnx::Add(%982, %984) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:185:0
  %986 : Tensor = onnx::Constant[value={1}]()
  %987 : Tensor = onnx::Constant[value={0}]()
  %988 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %989 : Tensor = onnx::Constant[value={4}]()
  %990 : Float(22800, 1) = onnx::Slice(%955, %987, %988, %986, %989) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:175:0
  %991 : Tensor = onnx::Constant[value={1}]()
  %992 : Tensor = onnx::Constant[value={1}]()
  %993 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %994 : Tensor = onnx::Constant[value={4}]()
  %995 : Float(22800, 1) = onnx::Slice(%955, %992, %993, %991, %994) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:176:0
  %996 : Tensor = onnx::Constant[value={1}]()
  %997 : Tensor = onnx::Constant[value={2}]()
  %998 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %999 : Tensor = onnx::Constant[value={4}]()
  %1000 : Float(22800, 1) = onnx::Slice(%955, %997, %998, %996, %999) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:177:0
  %1001 : Tensor = onnx::Constant[value={1}]()
  %1002 : Tensor = onnx::Constant[value={3}]()
  %1003 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1004 : Tensor = onnx::Constant[value={4}]()
  %1005 : Float(22800, 1) = onnx::Slice(%955, %1002, %1003, %1001, %1004) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:178:0
  %1006 : None = prim::Constant()
  %1007 : Double() = onnx::Constant[value={4.13517}]()
  %1008 : Tensor = onnx::Cast[to=1](%1007)
  %1009 : Float(22800, 1) = onnx::Clip(%1000, %1006, %1008) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:181:0
  %1010 : None = prim::Constant()
  %1011 : Double() = onnx::Constant[value={4.13517}]()
  %1012 : Tensor = onnx::Cast[to=1](%1011)
  %1013 : Float(22800, 1) = onnx::Clip(%1005, %1010, %1012) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:182:0
  %1014 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%964) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:184:0
  %1015 : Float(22800, 1) = onnx::Mul(%990, %1014) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:184:0
  %1016 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%979) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:184:0
  %1017 : Float(22800, 1) = onnx::Add(%1015, %1016) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:184:0
  %1018 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%973) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:185:0
  %1019 : Float(22800, 1) = onnx::Mul(%995, %1018) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:185:0
  %1020 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%985) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:185:0
  %1021 : Float(22800, 1) = onnx::Add(%1019, %1020) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:185:0
  %1022 : Float(22800, 1) = onnx::Exp(%1009) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:186:0
  %1023 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%964) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:186:0
  %1024 : Float(22800, 1) = onnx::Mul(%1022, %1023) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:186:0
  %1025 : Float(22800, 1) = onnx::Exp(%1013) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:187:0
  %1026 : Float(22800, 1) = onnx::Unsqueeze[axes=[1]](%973) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:187:0
  %1027 : Float(22800, 1) = onnx::Mul(%1025, %1026) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:187:0
  %1028 : Tensor = onnx::Shape(%955)
  %1029 : Float(22800, 4) = onnx::ConstantOfShape[value={0}](%1028) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:189:0
  %1030 : Float() = onnx::Constant[value={0.5}]()
  %1031 : Float(22800, 1) = onnx::Mul(%1024, %1030)
  %1032 : Float(22800, 1) = onnx::Sub(%1017, %1031) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:190:0
  %1033 : Float(22800, 1) = onnx::Flatten[axis=1](%1032) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:190:0
  %1034 : Tensor = onnx::Shape(%1029)
  %1035 : Tensor = onnx::Constant[value={0}]()
  %1036 : Long() = onnx::Gather[axis=0](%1034, %1035)
  %1037 : Tensor = onnx::Cast[to=7](%1036)
  %1038 : Tensor = onnx::Constant[value={0}]()
  %1039 : Tensor = onnx::Constant[value={1}]()
  %1040 : Tensor = onnx::Range(%1038, %1037, %1039)
  %1041 : Tensor = onnx::Shape(%1029)
  %1042 : Tensor = onnx::Constant[value={1}]()
  %1043 : Long() = onnx::Gather[axis=0](%1041, %1042)
  %1044 : Tensor = onnx::Cast[to=7](%1043)
  %1045 : Tensor = onnx::Constant[value={0}]()
  %1046 : Tensor = onnx::Constant[value={1}]()
  %1047 : Tensor = onnx::Range(%1045, %1044, %1046)
  %1048 : Tensor = onnx::Constant[value={0}]()
  %1049 : Tensor = onnx::Constant[value={0}]()
  %1050 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1051 : Tensor = onnx::Constant[value={4}]()
  %1052 : Tensor = onnx::Slice(%1047, %1049, %1050, %1048, %1051)
  %1053 : Tensor = onnx::Constant[value=-1  1 [ CPULongType{2} ]]()
  %1054 : Tensor = onnx::Reshape(%1040, %1053)
  %1055 : Tensor = onnx::Constant[value={-1}]()
  %1056 : Tensor = onnx::Reshape(%1052, %1055)
  %1057 : Tensor = onnx::Add(%1054, %1056)
  %1058 : Tensor = onnx::Shape(%1057)
  %1059 : Tensor = onnx::Shape(%1058)
  %1060 : Tensor = onnx::ConstantOfShape[value={1}](%1059)
  %1061 : Long() = onnx::Constant[value={-1}]()
  %1062 : LongTensor = onnx::Mul(%1060, %1061)
  %1063 : Tensor = onnx::Equal(%1058, %1062)
  %1064 : Tensor = onnx::Where(%1063, %1060, %1058)
  %1065 : Tensor = onnx::Expand(%1054, %1064)
  %1066 : Tensor = onnx::Unsqueeze[axes=[-1]](%1065)
  %1067 : Tensor = onnx::Shape(%1058)
  %1068 : Tensor = onnx::ConstantOfShape[value={1}](%1067)
  %1069 : Long() = onnx::Constant[value={-1}]()
  %1070 : LongTensor = onnx::Mul(%1068, %1069)
  %1071 : Tensor = onnx::Equal(%1058, %1070)
  %1072 : Tensor = onnx::Where(%1071, %1068, %1058)
  %1073 : Tensor = onnx::Expand(%1056, %1072)
  %1074 : Tensor = onnx::Unsqueeze[axes=[-1]](%1073)
  %1075 : Tensor = onnx::Concat[axis=-1](%1066, %1074)
  %1076 : Tensor = onnx::Shape(%1029)
  %1077 : Tensor = onnx::Constant[value={0}]()
  %1078 : Tensor = onnx::Constant[value={2}]()
  %1079 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1080 : Tensor = onnx::Slice(%1076, %1078, %1079, %1077)
  %1081 : Tensor = onnx::Concat[axis=0](%1058, %1080)
  %1082 : Tensor = onnx::Reshape(%1033, %1081)
  %1083 : Float(22800, 1) = onnx::ScatterND(%1029, %1075, %1082)
  %1084 : Float() = onnx::Constant[value={0.5}]()
  %1085 : Float(22800, 1) = onnx::Mul(%1027, %1084)
  %1086 : Float(22800, 1) = onnx::Sub(%1021, %1085) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:191:0
  %1087 : Float(22800, 1) = onnx::Flatten[axis=1](%1086) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:191:0
  %1088 : Tensor = onnx::Shape(%1083)
  %1089 : Tensor = onnx::Constant[value={0}]()
  %1090 : Long() = onnx::Gather[axis=0](%1088, %1089)
  %1091 : Tensor = onnx::Cast[to=7](%1090)
  %1092 : Tensor = onnx::Constant[value={0}]()
  %1093 : Tensor = onnx::Constant[value={1}]()
  %1094 : Tensor = onnx::Range(%1092, %1091, %1093)
  %1095 : Tensor = onnx::Shape(%1083)
  %1096 : Tensor = onnx::Constant[value={1}]()
  %1097 : Long() = onnx::Gather[axis=0](%1095, %1096)
  %1098 : Tensor = onnx::Cast[to=7](%1097)
  %1099 : Tensor = onnx::Constant[value={0}]()
  %1100 : Tensor = onnx::Constant[value={1}]()
  %1101 : Tensor = onnx::Range(%1099, %1098, %1100)
  %1102 : Tensor = onnx::Constant[value={0}]()
  %1103 : Tensor = onnx::Constant[value={1}]()
  %1104 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1105 : Tensor = onnx::Constant[value={4}]()
  %1106 : Tensor = onnx::Slice(%1101, %1103, %1104, %1102, %1105)
  %1107 : Tensor = onnx::Constant[value=-1  1 [ CPULongType{2} ]]()
  %1108 : Tensor = onnx::Reshape(%1094, %1107)
  %1109 : Tensor = onnx::Constant[value={-1}]()
  %1110 : Tensor = onnx::Reshape(%1106, %1109)
  %1111 : Tensor = onnx::Add(%1108, %1110)
  %1112 : Tensor = onnx::Shape(%1111)
  %1113 : Tensor = onnx::Shape(%1112)
  %1114 : Tensor = onnx::ConstantOfShape[value={1}](%1113)
  %1115 : Long() = onnx::Constant[value={-1}]()
  %1116 : LongTensor = onnx::Mul(%1114, %1115)
  %1117 : Tensor = onnx::Equal(%1112, %1116)
  %1118 : Tensor = onnx::Where(%1117, %1114, %1112)
  %1119 : Tensor = onnx::Expand(%1108, %1118)
  %1120 : Tensor = onnx::Unsqueeze[axes=[-1]](%1119)
  %1121 : Tensor = onnx::Shape(%1112)
  %1122 : Tensor = onnx::ConstantOfShape[value={1}](%1121)
  %1123 : Long() = onnx::Constant[value={-1}]()
  %1124 : LongTensor = onnx::Mul(%1122, %1123)
  %1125 : Tensor = onnx::Equal(%1112, %1124)
  %1126 : Tensor = onnx::Where(%1125, %1122, %1112)
  %1127 : Tensor = onnx::Expand(%1110, %1126)
  %1128 : Tensor = onnx::Unsqueeze[axes=[-1]](%1127)
  %1129 : Tensor = onnx::Concat[axis=-1](%1120, %1128)
  %1130 : Tensor = onnx::Shape(%1083)
  %1131 : Tensor = onnx::Constant[value={0}]()
  %1132 : Tensor = onnx::Constant[value={2}]()
  %1133 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1134 : Tensor = onnx::Slice(%1130, %1132, %1133, %1131)
  %1135 : Tensor = onnx::Concat[axis=0](%1112, %1134)
  %1136 : Tensor = onnx::Reshape(%1087, %1135)
  %1137 : Float(22800, 1) = onnx::ScatterND(%1083, %1129, %1136)
  %1138 : Float() = onnx::Constant[value={0.5}]()
  %1139 : Float(22800, 1) = onnx::Mul(%1024, %1138)
  %1140 : Float(22800, 1) = onnx::Add(%1017, %1139) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:192:0
  %1141 : Float(22800, 1) = onnx::Flatten[axis=1](%1140) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:192:0
  %1142 : Tensor = onnx::Shape(%1137)
  %1143 : Tensor = onnx::Constant[value={0}]()
  %1144 : Long() = onnx::Gather[axis=0](%1142, %1143)
  %1145 : Tensor = onnx::Cast[to=7](%1144)
  %1146 : Tensor = onnx::Constant[value={0}]()
  %1147 : Tensor = onnx::Constant[value={1}]()
  %1148 : Tensor = onnx::Range(%1146, %1145, %1147)
  %1149 : Tensor = onnx::Shape(%1137)
  %1150 : Tensor = onnx::Constant[value={1}]()
  %1151 : Long() = onnx::Gather[axis=0](%1149, %1150)
  %1152 : Tensor = onnx::Cast[to=7](%1151)
  %1153 : Tensor = onnx::Constant[value={0}]()
  %1154 : Tensor = onnx::Constant[value={1}]()
  %1155 : Tensor = onnx::Range(%1153, %1152, %1154)
  %1156 : Tensor = onnx::Constant[value={0}]()
  %1157 : Tensor = onnx::Constant[value={2}]()
  %1158 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1159 : Tensor = onnx::Constant[value={4}]()
  %1160 : Tensor = onnx::Slice(%1155, %1157, %1158, %1156, %1159)
  %1161 : Tensor = onnx::Constant[value=-1  1 [ CPULongType{2} ]]()
  %1162 : Tensor = onnx::Reshape(%1148, %1161)
  %1163 : Tensor = onnx::Constant[value={-1}]()
  %1164 : Tensor = onnx::Reshape(%1160, %1163)
  %1165 : Tensor = onnx::Add(%1162, %1164)
  %1166 : Tensor = onnx::Shape(%1165)
  %1167 : Tensor = onnx::Shape(%1166)
  %1168 : Tensor = onnx::ConstantOfShape[value={1}](%1167)
  %1169 : Long() = onnx::Constant[value={-1}]()
  %1170 : LongTensor = onnx::Mul(%1168, %1169)
  %1171 : Tensor = onnx::Equal(%1166, %1170)
  %1172 : Tensor = onnx::Where(%1171, %1168, %1166)
  %1173 : Tensor = onnx::Expand(%1162, %1172)
  %1174 : Tensor = onnx::Unsqueeze[axes=[-1]](%1173)
  %1175 : Tensor = onnx::Shape(%1166)
  %1176 : Tensor = onnx::ConstantOfShape[value={1}](%1175)
  %1177 : Long() = onnx::Constant[value={-1}]()
  %1178 : LongTensor = onnx::Mul(%1176, %1177)
  %1179 : Tensor = onnx::Equal(%1166, %1178)
  %1180 : Tensor = onnx::Where(%1179, %1176, %1166)
  %1181 : Tensor = onnx::Expand(%1164, %1180)
  %1182 : Tensor = onnx::Unsqueeze[axes=[-1]](%1181)
  %1183 : Tensor = onnx::Concat[axis=-1](%1174, %1182)
  %1184 : Tensor = onnx::Shape(%1137)
  %1185 : Tensor = onnx::Constant[value={0}]()
  %1186 : Tensor = onnx::Constant[value={2}]()
  %1187 : Tensor = onnx::Constant[value={9223372036854775807}]()
  %1188 : Tensor = onnx::Slice(%1184, %1186, %1187, %1185)
  %1189 : Tensor = onnx::Concat[axis=0](%1166, %1188)
  %1190 : Tensor = onnx::Reshape(%1141, %1189)
  %1191 : Float(22800, 1) = onnx::ScatterND(%1137, %1183, %1190)
  %1192 : Float() = onnx::Constant[value={0.5}]()
  %1193 : Float(22800, 1) = onnx::Mul(%1027, %1192)
  %1194 : Float(22800, 1) = onnx::Add(%1021, %1193) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:193:0
  %1195 : Float(22800, 1) = onnx::Flatten[axis=1](%1194) # /media/WorkSpace/Development_Repository/Working/models/boa/box_regression.py:193:0
  %1196 : Tensor = onnx::Shape(%1191)

The output model

What could be the reason for such an issue? The trace logs show that the data is not empty, but the exported model (as shown in the figure) has empty data. Could somebody give me some pointers to locate the problem? Thanks in advance.

Hi @Anakin,
It is really hard to tell, but from my experience the problem is in the tracing process.
When tracing, your model produces no boxes. A few things can cause this:

  1. You are tracing with random input instead of a real image.
  2. You are tracing an untrained model (so it produces no prediction boxes).
  3. You have some threshold somewhere which filters out the predictions.

Look in the code to see whether the predictions are filtered by a threshold or by NMS (non-maximum suppression, which may also have an internal confidence threshold). Set the threshold to 0 so all boxes pass the filter, and see whether this resolves your issue.
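The suggested check can be made concrete by logging how many predictions survive each filter during the same forward pass that is used for tracing. If the count is already zero at the configured threshold, the traced graph has no boxes to carry through to apply_deltas. A minimal NumPy sketch with hypothetical helpers and made-up scores (the real thresholds live in the model's config and NMS code, which are not shown here):

```python
import numpy as np


def filter_by_score(boxes, scores, score_thresh):
    """Keep boxes whose score is strictly greater than score_thresh (hypothetical helper)."""
    keep = scores > score_thresh
    return boxes[keep], scores[keep]


def survivors_per_threshold(scores, thresholds):
    """Count how many predictions would pass each candidate threshold."""
    return {t: int((scores > t).sum()) for t in thresholds}


# Made-up predictions for three boxes from one forward pass
boxes = np.array([[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 3, 3]], dtype=np.float32)
scores = np.array([0.9, 0.3, 0.05], dtype=np.float32)

print(survivors_per_threshold(scores, [0.0, 0.5, 0.95]))
kept_boxes, kept_scores = filter_by_score(boxes, scores, 0.0)
print(kept_boxes.shape)
```

Because the comparison is strict, a threshold of 0 keeps every box with a positive score, which is the "let everything through" setting suggested above.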