PyTorch Android cannot load a model saved with torch.jit.trace

Hi, I was following this example: https://pytorch.org/mobile/android/ except that I am using a Faster R-CNN model (fasterrcnn_resnet50_fpn) with a FastRCNNPredictor head. (The source code of trace_model.py is below.)

The first problem was that the output of Faster R-CNN is a list of dictionaries of tensors, while torch.jit.trace only accepts outputs that are a list of tensors or a tuple of tensors. But that's okay. I tracked down the error, looked into generalized_rcnn.py and changed the eager_outputs method, which is called at the end of forward:

def forward(self, images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
    """
    Arguments:
        images (list[Tensor]): images to be processed
        targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns list[BoxList] contains additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).

    """
    if self.training and targets is None:
        raise ValueError("In training mode, targets should be passed")
    original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
    for img in images:
        val = img.shape[-2:]
        assert len(val) == 2
        original_image_sizes.append((val[0], val[1]))

    images, targets = self.transform(images, targets)
    features = self.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([('0', features)])
    proposals, proposal_losses = self.rpn(images, features, targets)
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
    detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)

    if torch.jit.is_scripting():
        if not self._has_warned:
            warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
            self._has_warned = True
        return (losses, detections)
    else:
        return self.eager_outputs(losses, detections)

from

def eager_outputs(self, losses, detections):
    # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    if self.training:
        return losses

    return detections

to

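def eager_outputs(self, losses, detections):
    # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    if self.training:
        return losses
    # Flatten the list of per-image detection dicts into one flat list of
    # tensors, because torch.jit.trace cannot return dicts.
    tensor_output = []
    for dictionary in detections:
        for key in dictionary:
            tensor_output.append(dictionary[key])
    # Return only after every image has been processed (not inside the loop).
    return tensor_output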

and I was able to run trace_model.py to generate the .pt file.
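The flattening done by the modified eager_outputs can be sketched in plain Python. This is a hypothetical helper pair (flatten_detections / unflatten_detections are not part of torchvision), with strings standing in for tensors. It assumes each per-image dict has the keys "boxes", "labels" and "scores", which is what torchvision's Faster R-CNN returns in eval mode. Iterating over a fixed key tuple, rather than over the dict itself, keeps the position of each tensor in the flat list predictable for whatever reads it on the Java side.

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")

# Assumption: keys of each per-image detection dict (Faster R-CNN, eval mode).
DETECTION_KEYS = ("boxes", "labels", "scores")


def flatten_detections(detections: Sequence[Dict[str, T]],
                       keys: Sequence[str] = DETECTION_KEYS) -> List[T]:
    """Turn [{boxes, labels, scores}, ...] into [boxes0, labels0, scores0, boxes1, ...]."""
    flat: List[T] = []
    for per_image in detections:
        for key in keys:  # fixed order, independent of dict insertion order
            flat.append(per_image[key])
    return flat  # returned after *all* images, outside both loops


def unflatten_detections(flat: Sequence[T],
                         keys: Sequence[str] = DETECTION_KEYS) -> List[Dict[str, T]]:
    """Inverse of flatten_detections: regroup every len(keys) items into a dict."""
    n = len(keys)
    if len(flat) % n != 0:
        raise ValueError(f"expected a multiple of {n} items, got {len(flat)}")
    return [dict(zip(keys, flat[i:i + n])) for i in range(0, len(flat), n)]


# Usage with strings standing in for tensors (two images):
dets = [
    {"boxes": "b0", "labels": "l0", "scores": "s0"},
    {"boxes": "b1", "labels": "l1", "scores": "s1"},
]
flat = flatten_detections(dets)
print(flat)  # ['b0', 'l0', 's0', 'b1', 'l1', 's1']
assert unflatten_detections(flat) == dets
```

Regrouping by a fixed stride works because every image contributes exactly len(keys) items, even when it has no detections (the tensors are then simply empty).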

Then I used the HelloWorldApp to check whether everything compiles and whether I can load the model in Java. I only replaced model.pt on line 42 with the model generated by trace_model.py and ran this application: https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld.

The Gradle build succeeded, yet I keep getting errors I can't resolve. I am almost sure the problem is the module = Module.load(assetFilePath(this, "model_actual.pt")); line, since I commented out the rest and it was fine.

2020-06-13 01:16:33.541 30594-30594/org.pytorch.helloworld E/AndroidRuntime: FATAL EXCEPTION: main
    Process: org.pytorch.helloworld, PID: 30594
    java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: 
    Arguments for call are not valid.
    The following variants are available:
      
      aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> (Tensor):
      Expected at most 3 arguments but found 5 positional arguments.
      
      aten::upsample_bilinear2d.out(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> (Tensor(a!)):
      Argument out not provided.
    
    The original call is:
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\functional.py(3013): interpolate
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\transform.py(92): resize
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\transform.py(45): forward
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(534): _slow_forward
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(548): __call__
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\generalized_rcnn.py(72): forward
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(534): _slow_forward
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(548): __call__
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\jit\__init__.py(1027): trace_module
    C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\jit\__init__.py(875): trace
    d:\AndroidREPO\HelloWorldApp\android-demo-app\HelloWorldApp\trace_model.py(17): <module>
    Serialized   File "code/__torch__/torchvision/models/detection/transform.py", line 24
        _11 = torch.mul(torch.to(_9, 6, False, False, None), torch.detach(_10))
        _12 = torch.to(_11, 6, False, False, None)
        _13 = torch.upsample_bilinear2d(input, [int(_8), int(torch.floor(_12))], False, None, None)
              ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        img = torch.select(_13, 0, 0)
        height = ops.prim.NumToTensor(torch.size(img, 1))
    
        at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2946)
        at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:3081)
        at android.app.servertransaction.LaunchActivityItem.execute(LaunchActivityItem.java:78)
        at android.app.servertransaction.TransactionExecutor.executeCallbacks(TransactionExecutor.java:108)
        at android.app.servertransaction.TransactionExecutor.execute(TransactionExecutor.java:68)
        at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1831)
        at android.os.Handler.dispatchMessage(Handler.java:106)
        at android.os.Looper.loop(Looper.java:201)
        at android.app.ActivityThread.main(ActivityThread.java:6810)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:547)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:873)
     Caused by: com.facebook.jni.CppException: 
    Arguments for call are not valid.
    The following variants are available:
      
      aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> (Tensor):
      Expected at most 3 arguments but found 5 positional arguments.
      
      aten::upsample_bilinear2d.out(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> (Tensor(a!)):
      Argument out not provided.

And my trace_model.py code:

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

example = torch.rand(1, 3, 800, 800)
in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

model.eval()
PATH = 'model_weights.pth'
model.load_state_dict(torch.load(PATH))

traced_script_module = torch.jit.trace(model, example, check_trace = False)

traced_script_module.save("app/src/main/assets/model_actual.pt")

I would be very grateful if anyone could give me a hint about what might be causing the trouble here. I've gone through hundreds of lines of PyTorch's code, but I couldn't find anything. I'm still a beginner at programming. Thank you.

Hi!

I see that you have check_trace = False. Is there a reason for that? Also, did you check the traced module in PyTorch? Since the model returns a different number of bboxes per image and tracing doesn't support control flow, it may have recorded a different upsample_bilinear2d call. You can mix tracing and scripting.
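The point about control flow can be seen on a toy function (a stand-in, not the Faster R-CNN model): tracing records only the branch taken for the example input, scripting keeps both branches, and a traced function can call a scripted one, following the "mixing tracing and scripting" pattern from the TorchScript documentation. The names gate and pipeline are hypothetical.

```python
import torch


def gate(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: the branch depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return x + 10


positive = torch.ones(3)
negative = -torch.ones(3)

# 1) Tracing records only the ops executed for the example input, so the
#    "x * 2" branch is baked in (PyTorch emits a TracerWarning here).
traced_gate = torch.jit.trace(gate, positive)

# 2) Scripting compiles the Python source, so both branches survive.
scripted_gate = torch.jit.script(gate)


# 3) Mixing: trace the outer function but call the scripted one inside it,
#    so the control flow is preserved in the traced graph.
def pipeline(x: torch.Tensor) -> torch.Tensor:
    return scripted_gate(x) - 1


traced_pipeline = torch.jit.trace(pipeline, positive)

print(gate(negative))             # tensor([9., 9., 9.])
print(traced_gate(negative))      # tensor([-2., -2., -2.])  wrong branch
print(scripted_gate(negative))    # tensor([9., 9., 9.])
print(traced_pipeline(negative))  # tensor([8., 8., 8.])
```

The same idea applies to a larger model: the parts with data-dependent branches are the ones worth scripting.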

Finally, if you face an error in PyTorch Mobile, make sure that the exact same image runs well on the traced/scripted model in Python, to pinpoint where the bug resides (in your code, in TorchScript, in the mobile runtime, etc.).
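That check can be wrapped in a small helper. This is a sketch: traced_matches_eager and flatten_tensors are hypothetical names, and it assumes the model returns a tensor or a (nested) list/tuple/dict of tensors, as the modified eager_outputs does with its flat list. It saves the traced module and reloads it with torch.jit.load, so the comparison uses the same kind of .pt file that gets copied into the app's assets; it does not exercise the Android runtime itself. A toy Linear model is used in the usage example.

```python
import os
import tempfile
from typing import List

import torch


def flatten_tensors(output) -> List[torch.Tensor]:
    """Collect every tensor in a (possibly nested) list/tuple/dict output."""
    if isinstance(output, torch.Tensor):
        return [output]
    if isinstance(output, dict):
        # Sort keys so both outputs are compared in the same order.
        return [t for key in sorted(output) for t in flatten_tensors(output[key])]
    if isinstance(output, (list, tuple)):
        return [t for item in output for t in flatten_tensors(item)]
    raise TypeError(f"unsupported output type: {type(output).__name__}")


def traced_matches_eager(model: torch.nn.Module, example: torch.Tensor,
                         path: str, atol: float = 1e-5) -> bool:
    """Trace `model`, save it to `path`, reload it, and compare its output
    with eager mode on the same `example` input."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        reloaded = torch.jit.load(path)
        expected = flatten_tensors(model(example))
        actual = flatten_tensors(reloaded(example))
    if len(expected) != len(actual):
        return False
    return all(e.shape == a.shape and torch.allclose(e, a, atol=atol)
               for e, a in zip(expected, actual))


# Usage with a toy model:
toy = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.ReLU())
with tempfile.TemporaryDirectory() as tmp:
    print(traced_matches_eager(toy, torch.rand(1, 4), os.path.join(tmp, "toy.pt")))  # True
```

For a detection model, the comparison is only meaningful on the exact same image, because the number of boxes (and so the output shapes) can change from one image to the next.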