Hello,
I’ve tried to use a custom model on Android, but `forward` fails with an error I am unable to understand.
I’ve scripted my model with TorchScript (using type annotations) and am now trying to perform a forward pass on mobile.
The model looks something like this:
```python
from typing import Dict

import torch
import torch.nn as nn

# RPN, InputClass, Instances, N and config come from the project (not shown).
config = ...

class WrapRPN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rpn = RPN(config).eval().cpu()

    def forward(self, features):
        # type: (Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]
        mock_input: InputClass = InputClass(torch.rand((N, 320, 320)))
        instances = self.rpn(mock_input, features)
        output: Dict[str, torch.Tensor] = {}
        for idx in range(len(instances)):
            inst: Instances = instances[idx]
            box_tensor: torch.Tensor = inst.proposal_boxes.tensor
            output[str(idx)] = box_tensor
        return output
```
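Before looking at the Android side, it can help to confirm that a module with this exact `Dict[str, Tensor] -> Dict[str, Tensor]` signature scripts, saves, reloads, and runs on desktop. The sketch below is a minimal check under stated assumptions: `ToyDictModule` is a hypothetical stand-in for `WrapRPN` (the real `RPN`, `InputClass`, and `config` are not available here), and the keys `p2`..`p6` mirror the ones used in the Java code.

```python
import io
from typing import Dict

import torch
import torch.nn as nn


class ToyDictModule(nn.Module):
    """Hypothetical stand-in for WrapRPN with the same forward signature."""

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        output: Dict[str, torch.Tensor] = {}
        for key, value in features.items():
            # Return a one-element summary tensor per input key.
            output[key] = value.mean().unsqueeze(0)
        return output


scripted = torch.jit.script(ToyDictModule().eval())

# Round-trip through serialization, as the Android app loads a saved file.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

features = {k: torch.rand(1, 3, 32, 32) for k in ("p2", "p3", "p4", "p5", "p6")}
result = loaded(features)
print(sorted(result.keys()))  # ['p2', 'p3', 'p4', 'p5', 'p6']
```

If this runs on desktop, the scripted signature itself is fine, and the problem is more likely in how the input dict is built on the mobile side.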
The model converts and loads on mobile, but fails at runtime:
```
E/AndroidRuntime: FATAL EXCEPTION: main
Process: org.pytorch.helloworld, PID: 20157
java.lang.RuntimeException: Unable to start activity ComponentInfo{org.pytorch.helloworld/org.pytorch.helloworld.MainActivity}: com.facebook.jni.CppException: forward() Expected a value of type 'Dict[str, Tensor]' for argument 'features' but instead found type 'Dict[str, Tensor]'.
Position: 1
Declaration: forward(ClassType<WrapRPN> self, Dict(str, Tensor) features) -> (Dict(str, Tensor)) (checkArg at ../aten/src/ATen/core/function_schema_inl.h:194)
(no backtrace available)
    at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:3784)
    at android.app.ActivityThread.handleLaunchActivity(ActivityThread.java:3955)
    at android.app.servertransaction.LaunchActivityItem.execute(LaunchActivityItem.java:91)
    at android.app.servertransaction.TransactionExecutor.executeCallbacks(TransactionExecutor.java:149)
    at android.app.servertransaction.TransactionExecutor.execute(TransactionExecutor.java:103)
    at android.app.ActivityThread$H.handleMessage(ActivityThread.java:2392)
    at android.os.Handler.dispatchMessage(Handler.java:107)
    at android.os.Looper.loop(Looper.java:213)
    at android.app.ActivityThread.main(ActivityThread.java:8147)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:513)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1100)
Caused by: com.facebook.jni.CppException: forward() Expected a value of type 'Dict[str, Tensor]' for argument 'features' but instead found type 'Dict[str, Tensor]'.
Position: 1
Declaration: forward(ClassType<WrapRPN> self, Dict(str, Tensor) features) -> (Dict(str, Tensor)) (checkArg at ../aten/src/ATen/core/function_schema_inl.h:194)
(no backtrace available)
    at org.pytorch.NativePeer.forward(Native Method)
    at org.pytorch.Module.forward(Module.java:37)
    at org.pytorch.helloworld.MainActivity.onCreate(MainActivity.java:66)
    at android.app.Activity.performCreate(Activity.java:8068)
    at android.app.Activity.performCreate(Activity.java:8056)
    at android.app.Instrumentation.callActivityOnCreate(Instrumentation.java:1320)
    at android.app.ActivityThread.performLaunchActivity(ActivityThread.java:3757)
    ... 11 more
```
The Java code is as follows:
```java
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// Use the same image tensor for every feature key.
Map<String, IValue> hm = new HashMap<>();
List<String> keys = Arrays.asList("p2", "p3", "p4", "p5", "p6");
for (String key : keys) {
    hm.put(key, IValue.from(inputTensor));
}
final IValue rpn_input = IValue.dictStringKeyFrom(hm);
module.forward(rpn_input);
```
What does this mean? The expected and found types in the message are identical:

> Expected a value of type 'Dict[str, Tensor]' for argument 'features' but instead found type 'Dict[str, Tensor]'
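Since the schema check rejects a dict whose printed type matches the declared one, a plausible (unconfirmed) reading is that the dict built by `IValue.dictStringKeyFrom` on the Java/native side does not carry the exact type information the schema checker compares against. One way to sidestep the question entirely is to avoid passing a dict across the Java boundary. The sketch below is a hypothetical workaround, not tested on Android: `PositionalWrapper` accepts five plain tensors and rebuilds the `Dict[str, Tensor]` inside TorchScript, where its type is known statically. `ToyDictModule` again stands in for `WrapRPN`.

```python
from typing import Dict

import torch
import torch.nn as nn


class ToyDictModule(nn.Module):
    """Hypothetical stand-in for WrapRPN: takes and returns Dict[str, Tensor]."""

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        output: Dict[str, torch.Tensor] = {}
        for key, value in features.items():
            output[key] = value.mean().unsqueeze(0)
        return output


class PositionalWrapper(nn.Module):
    """Hypothetical wrapper: accepts five tensors instead of a dict, so the
    caller never has to construct a Dict value outside TorchScript."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, p2: torch.Tensor, p3: torch.Tensor, p4: torch.Tensor,
                p5: torch.Tensor, p6: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Build the dict inside the scripted graph, with a static type.
        features: Dict[str, torch.Tensor] = {
            "p2": p2, "p3": p3, "p4": p4, "p5": p5, "p6": p6,
        }
        return self.inner(features)


scripted = torch.jit.script(PositionalWrapper(ToyDictModule()).eval())
x = torch.rand(1, 3, 32, 32)
out = scripted(x, x, x, x, x)
```

On the Java side, such a model should then be callable with five `IValue.from(inputTensor)` arguments through the varargs `Module.forward(IValue...)`, with the result read back via `toDictStringKey()`. Whether returning a dict to Java behaves correctly on the PyTorch Android version in use is a separate question this sketch does not settle.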