I have implemented a UNet-like segmentation model based on MobileNetV3 for binary segmentation on an Android device. I trained the model on desktop and exported it for PyTorch Mobile via tracing. Its accuracy is as expected.
However, its latency is inconsistent across two applications on the same device. One is a benchmark app I wrote to compare different models by repeatedly feeding them fixed inputs; the other is the actual application that uses the model for segmentation (acquiring images via the CameraX API, running segmentation, and displaying the result).
In both applications, I measure the time used for inference:
long startTime = System.currentTimeMillis();
outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
long inferenceTime = System.currentTimeMillis() - startTime;
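A side note on the measurement itself: System.currentTimeMillis() is wall-clock time and can jump (e.g., on NTP adjustments), so for measuring intervals a monotonic clock such as System.nanoTime() (or SystemClock.elapsedRealtime() on Android) is safer, ideally after a few warm-up iterations. A minimal off-device sketch of that pattern, with a dummy numeric loop standing in for module.forward(...):

```java
// Sketch: monotonic interval timing with a warm-up phase.
// dummyWorkload() is a stand-in for the real module.forward(...) call,
// which cannot run outside the app.
public class TimingSketch {
    static double dummyWorkload() {
        double acc = 0;
        for (int i = 0; i < 100_000; i++) acc += Math.sqrt(i);
        return acc;
    }

    public static void main(String[] args) {
        int warmup = 5, runs = 20;
        for (int i = 0; i < warmup; i++) dummyWorkload(); // let JIT/caches settle

        long totalNanos = 0;
        for (int i = 0; i < runs; i++) {
            long start = System.nanoTime(); // monotonic, unlike currentTimeMillis()
            dummyWorkload();
            totalNanos += System.nanoTime() - start;
        }
        System.out.println("mean ms/run: " + (totalNanos / runs) / 1e6);
    }
}
```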
However, the model runs noticeably faster in my benchmark app (~27 ms/inference over 1000 runs) than in the actual application (~45 ms/inference), despite the inference code being identical. So my question is:
Which factors might slow down the inference step in one application but not in another?
In my benchmark app, I do the following:
1. Load the model.
2. Create a float array: all zeros, random values, or values taken from an image (either once at startup or once per inference). The values in the array do not appear to affect performance in any case.
3. Construct the input tensor from the float array.
4. Run inference.
5. Read the results via outputTensor.getDataAsFloatArray().
6. Repeat steps 2-5 n times.
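Condensed into a runnable sketch (pure Java; a placeholder Supplier stands in for the Module since PyTorch Mobile is not available off-device, and per-iteration timings are collected so the distribution, not just the mean, is visible):

```java
import java.util.Arrays;
import java.util.function.Supplier;

// Sketch of the benchmark loop. `inference` is a placeholder for
// module.forward(IValue.from(inputTensor)).toTensor().getDataAsFloatArray().
public class BenchmarkSketch {
    static long[] benchmark(Supplier<float[]> inference, int n) {
        long[] times = new long[n];
        for (int i = 0; i < n; i++) {
            long start = System.nanoTime();
            inference.get();                  // steps 3-5: tensor, inference, readout
            times[i] = System.nanoTime() - start;
        }
        return times;
    }

    public static void main(String[] args) {
        float[] input = new float[3 * 224 * 224];        // step 2: fixed zero input
        long[] t = benchmark(() -> input.clone(), 1000); // clone() stands in for inference
        Arrays.sort(t);
        System.out.printf("median: %.3f ms%n", t[t.length / 2] / 1e6);
    }
}
```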
In the actual application, I get an image from the camera, run pre-processing to produce a float array of normalized RGB values, pass it to the model for inference, and use the results in a post-processing step. Of course each full iteration takes longer, but I would expect the inference time itself to be consistent across both apps. It is not, even when I remove all pre- and post-processing steps and initialize the input tensor with fixed values at startup.
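For context, the pre-processing step amounts to roughly the following sketch (the real code reads pixels from a CameraX image rather than a plain int[] of ARGB values, and the mean/std normalization constants are model-specific parameters, not necessarily the ones I use):

```java
// Sketch: convert ARGB_8888 pixels into a normalized CHW float array,
// the channel layout PyTorch Mobile tensors typically expect.
public class Preprocess {
    static float[] toNormalizedChw(int[] argb, int w, int h,
                                   float[] mean, float[] std) {
        float[] out = new float[3 * w * h];
        int plane = w * h;
        for (int i = 0; i < plane; i++) {
            int p = argb[i];
            out[i]             = (((p >> 16) & 0xFF) / 255f - mean[0]) / std[0]; // R
            out[plane + i]     = (((p >> 8)  & 0xFF) / 255f - mean[1]) / std[1]; // G
            out[2 * plane + i] = ((p & 0xFF)         / 255f - mean[2]) / std[2]; // B
        }
        return out;
    }
}
```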
I am using org.pytorch:pytorch_android:1.7.0-SNAPSHOT
in both applications, and I have not changed any settings in either one (e.g., the number of threads). Any guidance is appreciated. Please let me know if I should provide more implementation details.