Hi, I was following the example at https://pytorch.org/mobile/android/ but with a Faster R-CNN network and FastRCNNPredictor (the source of trace_model.py is below).
The first problem was that Faster R-CNN outputs a list of dictionaries of tensors, whereas torch.jit.trace accepts only a list or tuple of tensors as output. That was manageable: I traced the error to generalized_rcnn.py and changed the eager_outputs method, which is called at the end of forward:
def forward(self, images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
    """
    Arguments:
        images (list[Tensor]): images to be processed
        targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns list[BoxList] contains additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).
    """
    if self.training and targets is None:
        raise ValueError("In training mode, targets should be passed")
    original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
    for img in images:
        val = img.shape[-2:]
        assert len(val) == 2
        original_image_sizes.append((val[0], val[1]))
    images, targets = self.transform(images, targets)
    features = self.backbone(images.tensors)
    if isinstance(features, torch.Tensor):
        features = OrderedDict([('0', features)])
    proposals, proposal_losses = self.rpn(images, features, targets)
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
    detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
    losses = {}
    losses.update(detector_losses)
    losses.update(proposal_losses)
    if torch.jit.is_scripting():
        if not self._has_warned:
            warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
            self._has_warned = True
        return (losses, detections)
    else:
        return self.eager_outputs(losses, detections)
from
def eager_outputs(self, losses, detections):
    # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    if self.training:
        return losses
    return detections
to
def eager_outputs(self, losses, detections):
    # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    if self.training:
        return losses
    tensor_output = list()
    for dictionary in detections:
        for key in dictionary:
            tensor_output.append(dictionary[key])
    return tensor_output
With this change I was able to run trace_model.py and generate the .pt file.
Then I used the HelloWorldApp to check whether everything compiles and whether I can load the model in Java. I only replaced model.pt with the model generated by trace_model.py (on line 42) and ran this application: https://github.com/pytorch/android-demo-app/tree/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld.
The Gradle build was successful, yet I keep getting errors that I can't resolve. I am almost sure the problem is in the module = Module.load(assetFilePath(this, "model_actual.pt")); call, since I've commented out the rest and it was fine.
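The flattening done in the modified eager_outputs can also be written as a standalone helper, which would avoid editing the installed torchvision file. This is only a minimal sketch: flatten_detections and DETECTION_KEYS are hypothetical names, the key order is an assumption based on the boxes, labels and scores entries that the torchvision detection models return in eval mode, and plain strings stand in for tensors so the snippet runs without torch.

```python
from typing import Any, Dict, List, Tuple

# Assumed fixed key order, so the position of each value in the flat
# output does not depend on dictionary iteration order.
DETECTION_KEYS = ("boxes", "labels", "scores")


def flatten_detections(detections: List[Dict[str, Any]]) -> Tuple[Any, ...]:
    """Flatten a list of per-image dicts into one flat tuple of values."""
    flat = []
    for per_image in detections:
        for key in DETECTION_KEYS:
            flat.append(per_image[key])
    return tuple(flat)


# Stand-in values; with a real model these would be torch.Tensor objects.
example_detections = [{"boxes": "b0", "labels": "l0", "scores": "s0"}]
flattened = flatten_detections(example_detections)
```

Such a helper could be called from a small wrapper module around the model, so that the object passed to torch.jit.trace returns a tuple of tensors while the library source stays untouched.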
2020-06-13 01:16:33.541 30594-30594/org.pytorch.helloworld E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.pytorch.helloworld, PID: 30594
java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException:
Arguments for call are not valid.
The following variants are available:
aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> (Tensor):
Expected at most 3 arguments but found 5 positional arguments.
aten::upsample_bilinear2d.out(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> (Tensor(a!)):
Argument out not provided.
The original call is:
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\functional.py(3013): interpolate
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\transform.py(92): resize
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\transform.py(45): forward
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(534): _slow_forward
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(548): __call__
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torchvision\models\detection\generalized_rcnn.py(72): forward
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(534): _slow_forward
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py(548): __call__
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\jit\__init__.py(1027): trace_module
C:\Users\patry\AppData\Roaming\Python\Python37\site-packages\torch\jit\__init__.py(875): trace
d:\AndroidREPO\HelloWorldApp\android-demo-app\HelloWorldApp\trace_model.py(17): <module>
Serialized File "code/__torch__/torchvision/models/detection/transform.py", line 24
_11 = torch.mul(torch.to(_9, 6, False, False, None), torch.detach(_10))
_12 = torch.to(_11, 6, False, False, None)
_13 = torch.upsample_bilinear2d(input, [int(_8), int(torch.floor(_12))], False, None, None)
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
img = torch.select(_13, 0, 0)
height = ops.prim.NumToTensor(torch.size(img, 1))
at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:2946)
at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:3081)
at android.app.servertransaction.LaunchActivityItem.execute(LaunchActivityItem.java:78)
at android.app.servertransaction.TransactionExecutor.executeCallbacks(TransactionExecutor.java:108)
at android.app.servertransaction.TransactionExecutor.execute(TransactionExecutor.java:68)
at android.app.ActivityThread$H.handleMessage(ActivityThread.java:1831)
at android.os.Handler.dispatchMessage(Handler.java:106)
at android.os.Looper.loop(Looper.java:201)
at android.app.ActivityThread.main(ActivityThread.java:6810)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:547)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:873)
Caused by: com.facebook.jni.CppException:
Arguments for call are not valid.
The following variants are available:
aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> (Tensor):
Expected at most 3 arguments but found 5 positional arguments.
aten::upsample_bilinear2d.out(Tensor self, int[2] output_size, bool align_corners, *, Tensor(a!) out) -> (Tensor(a!)):
Argument out not provided.
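The log itself contains one concrete clue: the serialized call passes more positional arguments than the operator schema available on the device accepts. The sketch below only counts the arguments in two strings copied from the log above. The helper names split_top_level and inner_args are hypothetical. One possible reading, which is an assumption and not a confirmed diagnosis, is that the PyTorch version used for tracing is newer than the PyTorch Android runtime bundled in the app.

```python
from typing import List


def split_top_level(args: str) -> List[str]:
    """Split on commas, ignoring commas nested inside () or []."""
    parts, depth, current = [], 0, []
    for ch in args:
        if ch in "([":
            depth += 1
        elif ch in ")]":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


def inner_args(text: str, name: str) -> str:
    """Return the text between the parentheses that follow `name`."""
    start = text.index(name + "(") + len(name) + 1
    depth = 1
    for i in range(start, len(text)):
        if text[i] == "(":
            depth += 1
        elif text[i] == ")":
            depth -= 1
            if depth == 0:
                return text[start:i]
    raise ValueError("unbalanced parentheses")


# Both strings are copied from the log above.
schema = ("aten::upsample_bilinear2d(Tensor self, int[2] output_size, "
          "bool align_corners) -> (Tensor)")
call = ("torch.upsample_bilinear2d(input, [int(_8), int(torch.floor(_12))], "
        "False, None, None)")

schema_params = split_top_level(inner_args(schema, "upsample_bilinear2d"))
call_args = split_top_level(inner_args(call, "upsample_bilinear2d"))
print(len(schema_params), "parameters in the schema,", len(call_args), "arguments in the call")
```

If the version mismatch reading is right, comparing the torch version used by trace_model.py with the pytorch_android dependency version in the app's Gradle file would be a reasonable first check.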
And my trace_model.py code:
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
example = torch.rand(1, 3, 800, 800)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)
model.eval()
PATH = 'model_weights.pth'
model.load_state_dict(torch.load(PATH))
traced_script_module = torch.jit.trace(model, example, check_trace = False)
traced_script_module.save("app/src/main/assets/model_actual.pt")
I would be very grateful if anyone could give me a hint about what might be causing the trouble here. I've gone through hundreds of lines of PyTorch code but couldn't find anything. I'm still a beginner at programming. Thank you.