Issue with FasterRCNN JIT Tracing

Hi. I’m trying to trace the FasterRCNN model from torchvision. Here’s the code:

import torch, torchvision
import numpy as np

def dict_to_tuple(d):
    # torch.jit.trace can't handle dict outputs, so unpack the detection
    # result into a (boxes, scores, labels) tuple
    return d["boxes"], d["scores"], d["labels"]

class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input):
        outputs = self.model(input)  # a list with one result dict per image
        return dict_to_tuple(outputs[0])

model = torchvision.models.detection.fasterrcnn_resnet50_fpn
model = TraceWrapper(model(pretrained=True))
model.eval()

batch_size = 1
channel = 3
img_width = 1024
img_height = 1024

input = torch.from_numpy(np.random.uniform(0.0, 250.0, size=(batch_size, channel, img_height, img_width))).float()

with torch.no_grad():
    scripted_model = torch.jit.trace(model, input)
    scripted_model.eval()

However, I’m encountering an issue. Here’s the error message I got at the torch.jit.trace line:

Traceback (most recent call last):
  File "/home/plg/yyz_workspace/tvm_DOTA/pass_reset_input.py", line 31, in <module>
    scripted_model = torch.jit.trace(model, input)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 1084, in trace_module
    _check_trace(
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/jit/_trace.py", line 562, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
        Graph diff:
                  graph(%self.1 : __torch__.TraceWrapper,
                        %input.1 : Tensor):
                    %model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                    %8 : float = prim::Constant[value=0.03125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                    %9 : float = prim::Constant[value=0.0625](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                    %10 : float = prim::Constant[value=0.125](), scope: __module.model/__module.model.roi_heads/__module.model.roi_heads.box_roi_pool # /home/plg/miniconda3/envs/vision_dota/lib/python3.11/site-packages/torch/_ops.py:502:0
                ......  // Thousands of lines omitted
                +   %2948 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%2947, %2921, %2922)
                ?      ^^                                                      ^^      ^     ^^
                -   %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2956)
                ?                                                                ^^
                +   %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::TupleUnpack(%2948)
                ?                                                                ^^
                    %7 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%4, %5, %6)
                    return (%7)
        First diverging operator:
        Node diff:
                - %model : __torch__.torchvision.models.detection.faster_rcnn.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                + %model : __torch__.torchvision.models.detection.faster_rcnn.___torch_mangle_355.FasterRCNN = prim::GetAttr[name="model"](%self.1)
                ?

I’ve tried changing torch.jit.trace to torch.jit.script, but that raised new errors, and the script approach doesn’t seem to meet my needs. What does this error mean, and what should I do to solve it? Thanks for your help!
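
For reference, torch.jit.trace also accepts check_trace=False, which skips the sanity check that raises this error. A minimal sketch of that workaround, assuming the divergence comes only from the data-dependent detection outputs (whether the resulting trace generalizes to other inputs would still need to be verified):

with torch.no_grad():
    # Skip the re-run comparison; detection models produce data-dependent
    # graphs, so the check can fail even when the trace itself is usable
    scripted_model = torch.jit.trace(model, input, check_trace=False)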

Could you check if torch.compile would work, as TorchScript is in maintenance mode?

Thanks for your reply! Do you mean checking whether torch.compile(model) works on its own, or adding the compilation somewhere in the tracing process? torch.compile(model) itself works.

I meant the former, i.e. removing the scripting and using torch.compile only.
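
Roughly like this (a minimal sketch; torch.compile returns an optimized callable with the same interface as the module, so no tracing step is involved):

compiled_model = torch.compile(model)  # model is the TraceWrapper instance
with torch.no_grad():
    boxes, scores, labels = compiled_model(input)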

Actually, I’m trying to convert this PyTorch model to a TVM Relay model, and TVM requires a scripted (traced) model to do its job, according to TVM’s docs:

Compile PyTorch Models — tvm 0.15.dev0 documentation
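
The conversion step in that tutorial looks roughly like this (a sketch based on the linked doc; relay.frontend.from_pytorch expects a traced TorchScript module plus a list of (input_name, shape) pairs, and the name "input" here is just an arbitrary label):

import tvm
from tvm import relay

# Sketch of the TVM import step: from_pytorch consumes a traced
# TorchScript module and the input name/shape list
shape_list = [("input", (1, 3, 1024, 1024))]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)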

So I’m wondering if there’s any other way to get the scripted model. If not, I’ll try other methods. Thanks!