Hi,
I updated PyTorch Mobile in my Android application from 1.7.1 to 1.8.0, and since then inference fails. The tensor dimensions somehow no longer seem to be valid for the PixelShuffle layer. Here is the error:
Process: com.example.pytorchtutorial, PID: 14078
java.lang.RuntimeException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/CARN_SR/carn_x2/___torch_mangle_86.py", line 130, in forward
_62 = getattr(self, "prepack_folding._jit_pass_packed_weight_32")
_63 = ops.prepacked.conv2d_clamp_run(_61, _62)
out8 = torch.pixel_shuffle(_63, 2)
~~~~~~~~~~~~~~~~~~~ <--- HERE
_64 = getattr(self, "prepack_folding._jit_pass_packed_weight_33")
out9 = ops.prepacked.conv2d_clamp_run(out8, _64)
Traceback of TorchScript, original code (most recent call last):
File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/pixelshuffle.py", line 46, in forward
def forward(self, input: Tensor) -> Tensor:
return F.pixel_shuffle(input, self.upscale_factor)
~~~~~~~~~~~~~~~ <--- HERE
RuntimeError: Dimension out of range (expected to be in range of [-6, 5], but got 4294967290)
at org.pytorch.NativePeer.forward(Native Method)
at org.pytorch.Module.forward(Module.java:49)
at com.example.pytorchtutorial.MainActivity.applySR(MainActivity.java:122)
at com.example.pytorchtutorial.MainActivity$2.onClick(MainActivity.java:214)
at android.view.View.performClick(View.java:7448)
at android.view.View.performClickInternal(View.java:7425)
at android.view.View.access$3600(View.java:810)
at android.view.View$PerformClick.run(View.java:28305)
at android.os.Handler.handleCallback(Handler.java:938)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loop(Looper.java:223)
at android.app.ActivityThread.main(ActivityThread.java:7656)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:592)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:947)
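One detail in the error message may help whoever looks into this: 4294967290 is exactly 2^32 - 6. In other words, the "out of range" dimension looks like -6 reinterpreted as an unsigned 32-bit integer. The message also says valid indices are [-6, 5], so the tensor being indexed is 6-dimensional, and -6 itself would be a legal index. That is consistent with (but does not prove) a signed-to-unsigned conversion bug in the 1.8.0 native code rather than a genuinely wrong input shape. The stdlib-only check below just verifies that arithmetic; the helper name `as_uint32` is hypothetical.

```python
# Reinterpret a signed integer as an unsigned 32-bit value,
# the way a two's-complement cast to uint32 would.
def as_uint32(value: int) -> int:
    return value & 0xFFFFFFFF


reported = 4294967290  # value printed in the RuntimeError
ndim = 6               # valid range [-6, 5] implies a 6-D tensor

original_index = reported - 2**32  # undo the wrap-around
print(original_index)                        # -6
print(as_uint32(original_index) == reported) # True
print(original_index in range(-ndim, ndim))  # True: -6 is a valid index
```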
When I downgrade PyTorch Mobile to 1.7.1, the error disappears. I think the problem lies in the PixelShuffle layer, because when I run inference with a different model, the error is thrown in that same layer. For example:
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/Models/ESPCN/___torch_mangle_163.py", line 29, in forward
_10 = (_4).forward(_8, )
_11 = (_3).forward()
_12 = (_0).forward((_1).forward((_2).forward(_10, ), ), )
~~~~~~~~~~~ <--- HERE
_13 = torch.contiguous(_12, memory_format=0)
return _13
File "code/__torch__/torch/nn/modules/pixelshuffle/___torch_mangle_160.py", line 8, in forward
def forward(self: __torch__.torch.nn.modules.pixelshuffle.___torch_mangle_160.PixelShuffle,
argument_1: Tensor) -> Tensor:
return torch.pixel_shuffle(argument_1, 2)
~~~~~~~~~~~~~~~~~~~ <--- HERE
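For reference, PixelShuffle with upscale factor r maps an input of shape (C*r*r, H, W) (per batch item) to (C, H*r, W*r), following the index rule out[c][h*r + i][w*r + j] = in[c*r*r + i*r + j][h][w]. The pure-Python sketch below (hypothetical helper `pixel_shuffle_ref`, not a PyTorch API) spells out that mapping. It could serve as a known-good baseline when comparing outputs from 1.7.1 and 1.8.0 on a small tensor. It is only an illustration of the rearrangement, not a replacement for the native kernel.

```python
def pixel_shuffle_ref(x, r):
    """Rearrange a (C*r*r, H, W) nested list into (C, H*r, W*r).

    Hypothetical reference helper mirroring the PixelShuffle index rule:
    out[c][h*r + i][w*r + j] = x[c*r*r + i*r + j][h][w]
    """
    in_channels = len(x)
    if in_channels % (r * r) != 0:
        raise ValueError("channel count must be divisible by upscale_factor**2")
    c_out = in_channels // (r * r)
    h, w = len(x[0]), len(x[0][0])

    # Allocate the upscaled output filled with zeros.
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):          # sub-pixel row offset
            for j in range(r):      # sub-pixel column offset
                src = x[c * r * r + i * r + j]
                for y in range(h):
                    for col in range(w):
                        out[c][y * r + i][col * r + j] = src[y][col]
    return out


# Four 1x1 channels with upscale factor 2 become one 2x2 channel.
print(pixel_shuffle_ref([[[0]], [[1]], [[2]], [[3]]], 2))  # [[[0, 1], [2, 3]]]
```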
I would appreciate any help with this. Thanks!
EDIT: This issue has now been reported on GitHub.