Hi everyone. I’m deploying a DeepLab model in an Android application. The output of the model is a [1, 21, 400, 400] tensor. To get the final mask, I have to run an argmax over the second dimension (the 21 classes). In Python I figure the operation would be something like `out = np.argmax(out, axis=1)`. Then I have to build a color mask from the result (`out[out == classId] = colorValue`).
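For reference, NumPy can do both steps without a per-pixel loop. This is a minimal sketch assuming the logits are already a NumPy array of shape [1, C, H, W]. The `palette` array (one RGB color per class) and the function name are hypothetical. Instead of one `out[out == classId] = colorValue` assignment per class, it uses a single palette lookup via fancy indexing.

```python
import numpy as np

NUM_CLASSES = 21


def logits_to_color_mask(out: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Turn [1, C, H, W] logits into an [H, W, 3] color image."""
    # argmax over the class dimension (axis=1) gives [1, H, W]; [0] drops the batch
    class_ids = np.argmax(out, axis=1)[0]
    # Fancy indexing: each class id picks its RGB row from the palette
    return palette[class_ids]


# Hypothetical palette: random colors, class 0 (background) set to black
rng = np.random.default_rng(0)
palette = rng.integers(0, 256, size=(NUM_CLASSES, 3), dtype=np.uint8)
palette[0] = 0

out = rng.standard_normal((1, NUM_CLASSES, 400, 400)).astype(np.float32)
mask = logits_to_color_mask(out, palette)
print(mask.shape, mask.dtype)  # (400, 400, 3) uint8
```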
But there is nothing quite like NumPy on Android. Looping over every pixel to find the maximum class and then setting the color is going to be very slow. Is there any way to do this efficiently? Any help would be greatly appreciated.
The easiest way is probably to write this post-processing in TorchScript and call it from the app.
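Here is a minimal sketch of what that could look like as a scripted plain function. It assumes the model output is a [1, C, H, W] float tensor. The `postprocess` name and the `palette` tensor (one RGB color per class) are hypothetical. The argmax and the color lookup both run inside PyTorch's native kernels, so no Java-side pixel loop is needed.

```python
import torch


@torch.jit.script
def postprocess(out: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    # out: [1, C, H, W] logits; palette: [C, 3] uint8 colors (hypothetical)
    class_ids = torch.argmax(out, dim=1)[0]  # [H, W] int64 class indices
    return palette[class_ids]                # [H, W, 3] colors via tensor indexing


# Tiny check: 3 classes, a 1x2 image
logits = torch.zeros(1, 3, 1, 2)
logits[0, 2, 0, 0] = 5.0  # pixel (0, 0) -> class 2
logits[0, 1, 0, 1] = 5.0  # pixel (0, 1) -> class 1
colors = torch.tensor([[0, 0, 0], [10, 20, 30], [40, 50, 60]], dtype=torch.uint8)
print(postprocess(logits, colors).tolist())
```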
You can also run libtorch proper on Android and write your own JNI wrapper for whatever you need.
Thanks for the reply. I’m not all that familiar with TorchScript or Android. What would the process look like? Could you recommend a starting place?
- Write your NumPy-like processing in PyTorch as a plain Python function or, better yet, as an `nn.Module`.
- Run it through `torch.jit.script`, as in the TorchScript tutorials.
- Then use it as shown in the Android tutorials.
Using an `nn.Module` in step 1 isn’t preferable in itself, but it reduces the amount of adaptation you need to do in steps 2 and 3, since the tutorials script, save, and load modules.
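The three steps can be sketched roughly as below. This is a minimal sketch under stated assumptions. `SegmentationWithMask`, the palette, and the file name are hypothetical. A 1x1 convolution stands in for the real DeepLab network so the example runs on its own. The wrapped model is assumed to return a plain [1, C, H, W] logits tensor with batch size 1.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SegmentationWithMask(nn.Module):
    """Runs a segmentation model and turns its logits into a color mask."""

    def __init__(self, model: nn.Module, palette: torch.Tensor):
        super().__init__()
        self.model = model
        # A registered buffer is saved together with the scripted module.
        self.register_buffer("palette", palette)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: the wrapped model returns a plain [1, C, H, W] tensor.
        out = self.model(x)
        class_ids = torch.argmax(out, dim=1)[0]  # [H, W] class indices (batch 1)
        return self.palette[class_ids]           # [H, W, 3] uint8 colors


torch.manual_seed(0)
NUM_CLASSES = 21
# Hypothetical stand-in for DeepLab: a 1x1 conv mapping RGB to 21 logits.
stand_in = nn.Conv2d(3, NUM_CLASSES, kernel_size=1)
# Hypothetical palette: one random RGB color per class.
palette = torch.randint(0, 256, (NUM_CLASSES, 3), dtype=torch.uint8)

# Step 2: script the whole module (model + post-processing) at once.
scripted = torch.jit.script(SegmentationWithMask(stand_in, palette))

# Step 3 starts from a saved file; the file name is hypothetical.
path = os.path.join(tempfile.mkdtemp(), "segmentation_with_mask.pt")
scripted.save(path)
```

If the model is one of torchvision's DeepLabV3 variants, its forward returns a dict, so the wrapper would take `self.model(x)["out"]` instead. Depending on the PyTorch Android version, the app may expect a lite-interpreter file (saved with `scripted._save_for_lite_interpreter(...)`) rather than the regular `save` output. On the Android side, the result would then come back as a flat tensor of H × W × 3 bytes.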