Hello, I have been having trouble deploying my pytorch model (which works fine on my python inference app) to android using Android Studio (Kotlin).
I have trained this CNN classifier:
class FoodClassifier(nn.Module):
def __init__(self, n_classes, freeze_backbone=False):
super(FoodClassifier, self).__init__()
effnet = efficientnet_b2(pretrained=True)
self.vision = nn.Sequential(*(list(effnet.children())[:-1]))
if freeze_backbone:
for param in self.vision.parameters():
param.requires_grad = False
print("[~] Vision Backbone is frozen")
self.classifier = nn.Sequential(
nn.Dropout(p=0.3, inplace=True),
nn.Linear(1408, n_classes, bias=True)
)
def forward(self, x):
x = self.vision(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
Then I exported it to ptl and onnx using the following scripts:
#!/usr/bin/env python3
import torch
import numpy as np
from torch.utils.mobile_optimizer import optimize_for_mobile
from model import *
from utils import *
MODEL_PATH = "models/FoodClassifier.pt"
print("[+] Model path:", MODEL_PATH)
if __name__ == "__main__":
dataset = FoodDataset(BASE_DIR)
model = FoodClassifier(len(dataset.classes))
model = load_model(MODEL_PATH, model)
model.eval()
print(model)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
out_path = MODEL_PATH.split(".")[0] + ".ptl"
traced_script_module_optimized._save_for_lite_interpreter(out_path)
print("Model saved at:", out_path)
#!/usr/bin/env python3
import torch
import numpy as np
from model import *
from utils import *
MODEL_PATH = "models/FoodClassifier.pt"
print("[+] Model path:", MODEL_PATH)
if __name__ == "__main__":
dataset = FoodDataset(BASE_DIR)
model = FoodClassifier(len(dataset.classes))
model = load_model(MODEL_PATH, model)
model.eval()
print(model)
t_input = torch.randn(1, 3, 224, 224)
out_path = MODEL_PATH.split(".")[0] + ".onnx"
torch.onnx.export(model, t_input, out_path, input_names=["image"], output_names=["food_class"], verbose=True)
print(f"[+] Exported {MODEL_PATH} at {out_path}")
These all have seemed right since my inference app can recognize a pizza from a jpg file:
#!/usr/bin/env python3
import os
import cv2
import numpy as np
import onnx
import onnxruntime
from PIL import Image
import torchvision.transforms as transforms
from model import *
BASE_DIR = "./data/"
W = H = 224
IDX = 30000 # pizza
MODEL_PATH = "models/FoodClassifier.pt"
MODEL_PATH_ONNX = "models/FoodClassifier.onnx"
print("[+] Model path:", MODEL_PATH)
def test_torch(img, label, classes):
with torch.no_grad():
model = FoodClassifier(len(classes))
model = load_model(MODEL_PATH, model)
model.eval()
img = img.unsqueeze(0).float()
print(img.shape)
print(img)
out = model(img)
print(out)
cat = np.argmax(out)
print(cat)
print("Prediction: ", classes[cat])
def test_onnx(img, label, classes):
onnx_model = onnx.load(MODEL_PATH_ONNX)
onnx.checker.check_model(onnx_model)
img = img.unsqueeze(0).float()
print(img)
session = onnxruntime.InferenceSession(MODEL_PATH_ONNX)
ort_inputs = {session.get_inputs()[0].name: img.detach().numpy()}
ort_outputs = session.run(None, ort_inputs)[0]
print(ort_outputs)
print(ort_outputs.shape)
cat = np.argmax(ort_outputs)
print(cat)
print("Prediction: ", classes[cat])
return cat, classes[cat]
if __name__ == "__main__":
# ...
transform = transforms.Compose([transforms.Resize((W, H)), transforms.ToTensor()])
img = Image.open(image_files[IDX][1])
image = transform(img)
data = {"image": image, "label": np.array(image_files[IDX][0]), "path": image_files[IDX][1]}
image, label, path = data["image"], data["label"], data["path"]
test_torch(image, label, classes)
pred, label = test_onnx(image, label, classes)
the input tensor is like this:
torch.Size([1, 3, 224, 224])
tensor([[[[0.7922, 0.8039, 0.8039, ..., 0.4392, 0.5412, 0.5020],
[0.8078, 0.8118, 0.8157, ..., 0.4667, 0.5098, 0.5059],
[0.8157, 0.8157, 0.8235, ..., 0.4902, 0.5176, 0.4941],
...,
[0.9569, 0.9569, 0.9529, ..., 0.8824, 0.8745, 0.8706],
[0.9569, 0.9569, 0.9529, ..., 0.8745, 0.8706, 0.8667],
[0.9529, 0.9490, 0.9490, ..., 0.8627, 0.8588, 0.8549]],
[[0.7843, 0.7961, 0.7961, ..., 0.4431, 0.5373, 0.4863],
[0.8000, 0.8039, 0.8078, ..., 0.4667, 0.5098, 0.4941],
[0.8078, 0.8078, 0.8157, ..., 0.4706, 0.5059, 0.4745],
...,
[0.9569, 0.9569, 0.9529, ..., 0.8824, 0.8784, 0.8745],
[0.9529, 0.9529, 0.9490, ..., 0.8784, 0.8745, 0.8706],
[0.9490, 0.9451, 0.9451, ..., 0.8784, 0.8745, 0.8706]],
[[0.7961, 0.8078, 0.8078, ..., 0.3961, 0.5255, 0.4667],
[0.8118, 0.8157, 0.8196, ..., 0.3961, 0.4510, 0.4549],
[0.8196, 0.8196, 0.8275, ..., 0.4667, 0.4627, 0.4549],
...,
[0.9569, 0.9490, 0.9451, ..., 0.9020, 0.8941, 0.8863],
[0.9451, 0.9451, 0.9373, ..., 0.8902, 0.8824, 0.8784],
[0.9373, 0.9333, 0.9294, ..., 0.8863, 0.8745, 0.8667]]]])
while both pt and onnx models have the same output: argmax = 17 (pizza),
although the values of pt model are something like:
tensor([[ -6.4250, -5.6502, -9.1750, -6.7315, -6.1992, -4.3728, -7.5658,
and onnx’s:
tensor([[[[0.7922, 0.8039, 0.8039, ..., 0.4392, 0.5412, 0.5020],
which I don’t think plays a big role
Given that, I created an inference app on android (for both onnx and ptl):
CameraActivity.kt:
fun startCNNModel(imagePath: String) {
try {
Log.e("CNN", "Hello from CNN");
var foodDetector = FoodDetector()
var imageStream = assets.open("model/pizza.jpg") // TODO: replace with imagePath
// Load onnx model
var modelPath = assetFilePath(this, "FoodClassifier.onnx")
var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
var ortSession: OrtSession = ortEnv.createSession(modelPath);
var model_out = foodDetector.detectOnnx(imageStream, ortEnv, ortSession)
// torchscript
// val module = LiteModuleLoader.load(assetFilePath(this, "FoodClassifier.ptl"))
// var model_out = foodDetector.detect(imageStream, module)
val foodNamesModel = arrayOf("pizza", "pizza", "pizza")
lifecycleScope.launch(Dispatchers.IO) {
val foodNames = foodNamesModel.asList().joinToString(",").split(",").toTypedArray()
// Translate the foodNames to the correct locale if necessary
val translatedFoodNames = if (Locale.getDefault().language != "en") {
cameraViewModel.translateFoodNames(foodNames)
} else {
foodNames
}
runOnUiThread {
binding.cameraProgressCircle.hide()
showPredictionCheckDialog(translatedFoodNames)
}
}
} catch (e: Exception) {
Log.e("Inference", "Error running inference", e)
}
FoodDetector.kt (onnx):
fun preprocessImageOnnx(inputStream: InputStream, ortEnv: OrtEnvironment): OnnxTensor {
val inputWidth = 224
val inputHeight = 224
// Resize the bitmap
val bitmap = BitmapFactory.decodeStream(inputStream)
// Resize the image
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true)
// Convert the bitmap to a float array
val floatArray = FloatArray(resizedBitmap.width * resizedBitmap.height * 3) // 3 channels for RGB
for (y in 0 until resizedBitmap.height) {
for (x in 0 until resizedBitmap.width) {
val pixel = resizedBitmap.getPixel(x, y)
val red = Color.red(pixel) / 255.0f
val green = Color.green(pixel) / 255.0f
val blue = Color.blue(pixel) / 255.0f
// Assuming RGB order, you may need to adapt this to match your image format
val index = (y * resizedBitmap.width + x) * 3
floatArray[index] = red
floatArray[index + 1] = green
floatArray[index + 2] = blue
}
}
println(floatArray.contentToString())
val byteBuffer = ByteBuffer.allocateDirect(floatArray.size * 4).apply {
asFloatBuffer().put(floatArray)
}
val inputTensor = OnnxTensor.createTensor(
ortEnv,
byteBuffer,
longArrayOf(1, 3, inputHeight.toLong(), inputWidth.toLong()),
OnnxJavaType.FLOAT
)
return inputTensor
}
fun detectOnnx(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): FloatArray {
val inputTensor = preprocessImageOnnx(inputStream, ortEnv)
inputTensor.use {
// Step 4: Call ONNX inferenceSession run
val output = ortSession.run(
Collections.singletonMap("image", inputTensor),
setOf("food_class")
)
Log.d("CNN out", output.toString())
output.use {
val classificationOutput = (output?.get(0)?.value) as Array<FloatArray>
// Assuming classificationOutput is of shape (1, 256)
val probabilities = classificationOutput[0]
println(probabilities.contentToString())
Log.e("Output Class", argmax(probabilities).toString())
Log.e("Output Class", topK(probabilities, 5).contentToString())
return probabilities
}
}
}
FoodDetector.kt (ptl):
fun preprocessImage(inputStream: InputStream): Tensor {
//Decode the input stream into a bitmap
val bitmap = BitmapFactory.decodeStream(inputStream)
// Resize the image
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true)
// val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
// resizedBitmap,
// TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
// TensorImageUtils.TORCHVISION_NORM_STD_RGB
// )
val floatArray = FloatArray(resizedBitmap.width * resizedBitmap.height * 3) // 3 channels for RGB
for (y in 0 until resizedBitmap.height) {
for (x in 0 until resizedBitmap.width) {
val pixel = resizedBitmap.getPixel(x, y)
val red = Color.red(pixel) / 255.0f
val green = Color.green(pixel) / 255.0f
val blue = Color.blue(pixel) / 255.0f
// Assuming RGB order, you may need to adapt this to match your image format
val index = (y * resizedBitmap.width + x) * 3
floatArray[index] = red
floatArray[index + 1] = green
floatArray[index + 2] = blue
}
}
println(floatArray.contentToString())
println(floatArray.size)
val inputTensor = Tensor.fromBlob(floatArray, longArrayOf(1, 3, IMAGE_WIDTH.toLong(), IMAGE_HEIGHT.toLong()))
println(inputTensor.dataAsFloatArray.contentToString())
println(inputTensor.shape().contentToString())
inputStream.close()
return inputTensor
}
fun detect(inputStream: InputStream, module: Module): FloatArray{
val inputTensor = preprocessImage(inputStream)
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
val outputArray = outputTensor.dataAsFloatArray
println(outputArray.contentToString())
Log.e("Output Class", argmax(outputArray).toString())
Log.e("Output Class", topK(outputArray, 5).contentToString())
// TODO: use class names and return top 5 names
return outputArray
}
Note that I have already tried the classic approach in the tutorials:
val inputStream: InputStream = assetManager.open(imageName)
BitmapFactory.decodeStream(inputStream)
...
val tensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB.,
TensorImageUtils.TORCHVISION_NORM_STD_RGB)
And after a long set of experimenting, I haven’t gotten anywhere. The input tensor values on android are these:
[0.7882353, 0.78039217, 0.7921569, 0.8039216, 0.79607844, 0.80784315, 0.8039216, 0.79607844, 0.80784315, 0.8117647, 0.8039216, 0.8156863, 0.8156863, 0.80784315 ....
Which are slightly different than the ones PIL loads from pizza.jpg.
The onnx output tensor is this:
[NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN,
and the pytorch lite’s is this:
[-4563.728, -14917.566, -14551.325, -16742.281, -10390.229, -11888.001, -15174.114, -13337.571, ...
argmax: 86
topK: [86, 0, 35, 16, 19]
which leads to a completely wrong prediction.
There isn’t much documentation on this, even though a lot of people have had issues deploying on Android in the past. Any ideas why this is happening? Since using the documented implementations was unsuccessful, what is the right way to deploy this model to Android?