Big trouble exporting Deformable DETR

When I try to export the Deformable DETR model to TorchScript, it shows this error message:

Could not export Python function call 'MSDeformAttnFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/root/autodl-tmp/project/deepsolo/adet/layers/ms_deform_attn.py(165): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/deepsolo/adet/layers/deformable_transformer.py(286): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/deepsolo/adet/layers/deformable_transformer.py(413): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/deepsolo/adet/layers/deformable_transformer.py(173): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/deepsolo/adet/modeling/model/detection_transformer.py(200): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/deepsolo/adet/modeling/text_spotter_v1.py(222): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/detectron2/export/flatten.py(259): <lambda>
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/detectron2/export/flatten.py(294): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py(952): trace_module
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py(735): trace
deploy/export_model00.py(126): export_tracing
deploy/export_model00.py(226): <module>
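For context on this first error: in the Deformable DETR reference code, MSDeformAttnFunction is a custom torch.autograd.Function wrapping a compiled CUDA op, so the tracer can only record it as an opaque Python call that it cannot serialize. A common workaround is to route the attention through a pure-PyTorch path built on F.grid_sample. The sketch below is modeled on Deformable DETR's reference `ms_deform_attn_core_pytorch`; the tensor layouts and the list-of-(H, W)-ints form of `value_spatial_shapes` are assumptions about how DeepSolo's layer (ms_deform_attn.py, line 165 in the traceback) would call it. It is slower than the CUDA kernel, and because its Python loop runs over fixed per-level shapes, tracing still freezes those shapes: it removes this export error but does not by itself make the trace work for other image sizes.

```python
import torch
import torch.nn.functional as F


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Pure-PyTorch multi-scale deformable attention (sketch of Deformable DETR's fallback).

    value:                (N, S, M, D)  flattened multi-level features, S = sum(H_l * W_l)
    value_spatial_shapes: list of (H_l, W_l) Python ints, one per level (assumption)
    sampling_locations:   (N, Lq, M, L, P, 2) normalized (x, y) in [0, 1]
    attention_weights:    (N, Lq, M, L, P)
    returns:              (N, Lq, M * D)
    """
    N, _, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    # Split the flattened features back into one chunk per level.
    value_list = value.split([h * w for h, w in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1].
    sampling_grids = 2 * sampling_locations - 1
    sampled = []
    for lvl, (h, w) in enumerate(value_spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        value_l = value_list[lvl].flatten(2).transpose(1, 2).reshape(N * M, D, h, w)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid_l = sampling_grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # Bilinear sampling -> (N*M, D, Lq, P)
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    # Weighted sum over levels and points -> (N*M, D, Lq)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
    # (N*M, D, Lq) -> (N, M*D, Lq) -> (N, Lq, M*D)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()
```

The call to MSDeformAttnFunction.apply(...) in the attention module's forward would be replaced by a call like this; the exact argument names there are not shown in this thread.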

Then I tried a workaround, and the export succeeded, but the exported model doesn't work! It can only run inference on the image I used for exporting the model; for any other image it fails with this error message:

/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py:1051: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return forward_call(*input, **kwargs)
Traceback (most recent call last):
  File "/root/autodl-tmp/project/deploy/export_model.py", line 264, in <module>
    out1 = m(data)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/detectron2/export/flatten.py", line 9, in forward
  def forward(self: __torch__.detectron2.export.flatten.TracingAdapter,
    argument_1: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    _0, _1, _2, _3, _4, _5, _6, = (self.model).forward(argument_1, )
                                   ~~~~~~~~~~~~~~~~~~~ <--- HERE
    return (_0, _1, _2, _3, _4, _5, _6)
  File "code/__torch__/adet/modeling/text_spotter.py", line 23, in forward
    batched_imgs = torch.unsqueeze_(_7, 0)
    x0 = torch.contiguous(batched_imgs)
    _8, _9, _10, _11, = (_0).forward(x0, image_size, )
                         ~~~~~~~~~~~ <--- HERE
    _12 = torch.softmax(_9, -1)
    prob = torch.sigmoid(torch.mean(_8, [-2]))
  File "code/__torch__/adet/modeling/model/detection_transformer.py", line 50, in forward
    _29 = getattr(self.input_proj, "1")
    _30 = getattr(self.input_proj, "0")
    _31 = (self.backbone).forward(x, image_size, )
           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, = _31
    _58 = (_30).forward(_32, )
  File "code/__torch__/adet/modeling/text_spotter.py", line 104, in forward
    image_size: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    _61 = getattr(self, "1")
    _62 = (getattr(self, "0")).forward(x, image_size, )
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _63, _64, _65, _66, _67, _68, _69, = _62
    pos_embed = torch.to((_61).forward(_63, ), 6)
  File "code/__torch__/adet/modeling/text_spotter.py", line 143, in forward
    _92 = torch.slice(torch.slice(_91, 0, 0, 125), 1, 0, 138)
    _93 = torch.view(CONSTANTS.c2, annotate(List[int], []))
    _94 = torch.copy_(_92, torch.expand(_93, [125, 138]))
          ~~~~~~~~~~~ <--- HERE
    masks_per_feature_level0 = torch.ones([_85, _86, _87], dtype=11, layout=None, device=torch.device("cpu"), pin_memory=False)
    _95 = torch.select(masks_per_feature_level0, 0, 0)

Traceback of TorchScript, original code (most recent call last):
/root/autodl-tmp/project/adet/modeling/text_spotter.py(60): mask_out_padding
/root/autodl-tmp/project/adet/modeling/text_spotter.py(43): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/adet/modeling/text_spotter.py(21): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/adet/modeling/model/detection_transformer.py(168): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/autodl-tmp/project/adet/modeling/text_spotter.py(220): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/detectron2/export/flatten.py(259): <lambda>
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/detectron2/export/flatten.py(294): forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1039): _slow_forward
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py(1051): _call_impl
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py(952): trace_module
/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py(735): trace
/root/autodl-tmp/project/deploy/export_model.py(125): export_tracing
/root/autodl-tmp/project/deploy/export_model.py(224): <module>
/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py(18): execfile
/root/.pycharm_helpers/pydev/pydevd.py(1496): _exec
/root/.pycharm_helpers/pydev/pydevd.py(1489): run
/root/.pycharm_helpers/pydev/pydevd.py(2177): main
/root/.pycharm_helpers/pydev/pydevd.py(2195): <module>
RuntimeError: The size of tensor a (50) must match the size of tensor b (125) at non-singleton dimension 0
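Reading the serialized code above, the slice bounds 125 and 138 and the `expand(..., [125, 138])` are hard-coded, and the failing copy_ compares a 50-row slice against 125. That suggests the valid-region size computed in mask_out_padding (text_spotter.py, line 60) was turned into Python numbers during tracing and frozen to the sample image. A minimal sketch, assuming the mask is "True on padding, False on valid pixels" and that a per-level stride is known, builds it from tensor ops only, with no int() or in-place slice assignment, so a trace follows the input shape. `padding_mask`, `mask_stride8`, and the stride value 8 are hypothetical names and values, not DeepSolo code. An in-place boolean slice assignment may also be what triggers the aten::fill_ assert in the ONNX attempt further down, though that is only a guess.

```python
import torch


def padding_mask(feat: torch.Tensor, image_hw: torch.Tensor, stride: int) -> torch.Tensor:
    """Bool mask of shape (N, H, W): True = padding, False = valid pixels.

    feat:     (N, C, H, W) feature map of one level
    image_hw: (N, 2) valid (height, width) of each image in input pixels
    stride:   downsampling factor of this level (assumed, e.g. 8, 16, 32)
    """
    # Valid extent per image in feature-map cells, kept as a tensor (no int()).
    valid = torch.ceil(image_hw.to(feat.dtype) / stride)        # (N, 2)
    # Row/column indices derived from feat itself, so the trace follows its shape.
    rows = torch.ones_like(feat[:, 0, :, 0]).cumsum(-1) - 1      # (N, H)
    cols = torch.ones_like(feat[:, 0, 0, :]).cumsum(-1) - 1      # (N, W)
    pad_rows = rows >= valid[:, 0:1]                             # (N, H)
    pad_cols = cols >= valid[:, 1:2]                             # (N, W)
    # A cell is padding if it lies outside the valid rows or columns.
    return pad_rows[:, :, None] | pad_cols[:, None, :]


def mask_stride8(feat: torch.Tensor, image_hw: torch.Tensor) -> torch.Tensor:
    return padding_mask(feat, image_hw, 8)


# Trace with one size, then run the traced function on a different size.
traced = torch.jit.trace(mask_stride8, (torch.zeros(2, 3, 8, 10), torch.tensor([[60, 70], [40, 80]])))
feat_b, hw_b = torch.zeros(1, 3, 5, 6), torch.tensor([[33, 20]])
print(torch.equal(traced(feat_b, hw_b), mask_stride8(feat_b, hw_b)))  # True
```

This only addresses the mask; other parts of the model may still freeze sizes under tracing.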

By the way, after failing to export the model to TorchScript, I tried exporting it to ONNX, but I still ran into problems and failed.
With torch 1.9 and CUDA 11.1, the error message is:

Traceback (most recent call last):
  File "deploy/export_model00.py", line 226, in <module>
    sample_inputs = get_sample_inputs(args)
  File "deploy/export_model00.py", line 132, in export_tracing
    torch.onnx.export(traceable_model, (image,), f,opset_version=11)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/__init__.py", line 275, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 689, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 501, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
RuntimeError: expected scalar type Long but found Float

I searched online and found a suggestion to upgrade torch to 1.10. I upgraded to torch 1.10, but it still fails:

Traceback (most recent call last):
  File "deploy/export_model00.py", line 227, in <module>
    exported_model = export_tracing(torch_model, sample_inputs)
  File "deploy/export_model00.py", line 132, in export_tracing
    torch.onnx.export(traceable_model, image, f, input_names=['input'], output_names=['output'], opset_version=11,dynamic_axes={'input':[1,2],'output':[1,2]})
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/envs/deepsolo/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
RuntimeError: 0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":584, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case.  Argument types: Tensor, bool, 

@ptrblck I'd appreciate it if you could reply!

Details of the problem:
What I'm trying to do:
I'm trying to export the model so that I can run inference in a C++ environment. (The model is from GitHub - ViTAE-Transformer/DeepSolo: The official repo for [CVPR'23] "DeepSolo: Let Transformer Decoder with Explicit Points Solo for Text Spotting" & [ArXiv'23] "DeepSolo++: Let Transformer Decoder with Explicit Points Solo for Text Spotting")

How I export the model:
I use the official script export_model.py (detectron2/tools/deploy at main · facebookresearch/detectron2 · GitHub) with these parameters:
--export-method tracing
--format torchscript
MODEL.DEVICE cpu
MODEL.WEIGHTS (from GitHub - ViTAE-Transformer/DeepSolo: The official repo for [CVPR'23] "DeepSolo: Let Transformer Decoder with Explicit Points Solo for Text Spotting" & [ArXiv'23] "DeepSolo++: Let Transformer Decoder with Explicit Points Solo for Text Spotting")

--config-file:

After running export_model.py, it shows this error message:

Could not export Python function call 'MSDeformAttnFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

Then I searched online and finally solved this using the approach from "Can model convert to torchscript? · Issue #37 · facebookresearch/Mask2Former · GitHub".

That gave me the exported model, a model.ts file.

This is how the input image is processed:

But when I use model.ts to run inference on an image:

import torch

model_script = torch.jit.load("model.ts")
# sample_inputs comes from get_sample_inputs(args) in export_model.py
sample_image = sample_inputs[0]["image"]
out = model_script(sample_image)

it only works on the image given by args.sample_image. When I change args.sample_image to another image, it shows this error message:

RuntimeError: The size of tensor a (50) must match the size of tensor b (125) at non-singleton dimension 0
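One way to catch this at export time rather than at deployment is torch.jit.trace's check_inputs argument: it re-traces the function on each extra input and raises if the recorded graph differs, which is what happens when sizes or values from the example get frozen as constants. The sketch below is a toy, not DeepSolo code: `frozen_pad_mask` is a hypothetical function that mimics a mask_out_padding-style int() conversion, and `trace_checked` is a hypothetical helper for a custom export script. In the PyTorch versions I'm aware of, the failure surfaces as a TracingCheckError, preceded by TracerWarnings.

```python
import torch


def frozen_pad_mask(feat: torch.Tensor, image_hw: torch.Tensor) -> torch.Tensor:
    """Hypothetical mask builder that converts sizes to Python ints (trace-unfriendly)."""
    n, _, h, w = feat.shape
    mask = torch.ones(n, h, w, dtype=torch.bool)
    # int() turns these values into constants in the traced graph (TracerWarning).
    vh, vw = int(image_hw[0, 0]), int(image_hw[0, 1])
    mask[0, :vh, :vw] = False
    return mask


def trace_checked(fn, example, *other_inputs):
    """Trace fn on `example`, then re-trace on every other input and compare graphs.

    torch.jit.trace raises if the graph recorded for any check input differs,
    i.e. if the trace depends on the example's specific values or sizes.
    """
    return torch.jit.trace(fn, example, check_inputs=[example, *other_inputs])


example = (torch.zeros(1, 3, 8, 10), torch.tensor([[6, 7]]))
other = (torch.zeros(1, 3, 5, 6), torch.tensor([[4, 3]]))
try:
    trace_checked(frozen_pad_mask, example, other)
    print("trace generalizes")
except Exception as err:
    print("trace is specialized to the example input:", type(err).__name__)
```

For the exported model, the extra inputs would be images of other sizes passed through the same adapter used for tracing.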

That's all the information I can provide. Looking forward to your reply!
Thank you very much!

TorchScript is in maintenance mode and I doubt it will get any major fixes at this time.
You could try torch.compile on the model and see if that works.

But my model is based on Detectron2, which hasn't been updated in over a year. Its current release, v0.6, supports at most CUDA 11.3 + PyTorch 1.10.0.
So do you mean it's possible that TorchScript can't export my model?
I was confused because my model was successfully exported to TorchScript via tracing, but it can only run inference on the image I used for exporting the model, not on other images.

Tracing the model bakes in the outcome of all conditions and won't record the conditions themselves (as scripting the model would). It's unclear to me whether the shape mismatch is caused by exactly this limitation, i.e. whether the model's dynamic behavior, which tracing does not capture, now causes the shape mismatch, or where exactly it's coming from.
My response explaining that TorchScript is in maintenance mode was just a warning that you shouldn't expect any fixes in it, since you had tagged me.
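A toy illustration (not DeepSolo code) of the tracing limitation described above: `shift` is a hypothetical function with a data-dependent branch. torch.jit.trace records only the branch taken for the example input, while torch.jit.script keeps the if/else. Scripting DeepSolo itself would additionally require every submodule, including the deformable attention, to be scriptable, which this sketch does not address.

```python
import torch


def shift(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent Python branch.
    if x.sum() > 0:
        return x * 2
    return x + 10


traced = torch.jit.trace(shift, torch.ones(3))   # records only the "x * 2" branch (TracerWarning)
scripted = torch.jit.script(shift)              # compiles the full if/else

neg = -torch.ones(3)
print(shift(neg))     # tensor([9., 9., 9.])
print(traced(neg))    # tensor([-2., -2., -2.])  (frozen branch, silently wrong)
print(scripted(neg))  # tensor([9., 9., 9.])
```

Note that the traced version does not raise here; it silently returns the wrong branch, which is why a shape or branch mismatch in a traced model can surface far from its cause.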