CNN inference behaves differently on Android than in PyTorch

Hello, I have been having trouble deploying my PyTorch model (which works fine in my Python inference app) to Android using Android Studio (Kotlin).
I have trained this CNN classifier:

class FoodClassifier(nn.Module):
  def __init__(self, n_classes, freeze_backbone=False):
    super(FoodClassifier, self).__init__()

    effnet = efficientnet_b2(pretrained=True)
    self.vision = nn.Sequential(*(list(effnet.children())[:-1]))

    if freeze_backbone:
      for param in self.vision.parameters():
        param.requires_grad = False
      print("[~] Vision Backbone is frozen")

    self.classifier = nn.Sequential(
      nn.Dropout(p=0.3, inplace=True),
      nn.Linear(1408, n_classes, bias=True)
    )

  def forward(self, x):
    x = self.vision(x)
    x = x.view(x.size(0), -1)
    x = self.classifier(x)
    return x
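
For reference, children()[:-1] keeps torchvision's EfficientNet-B2 feature extractor plus its average-pool layer, which is why the linear head takes 1408 inputs. A minimal shape check (the class count here is only a placeholder) would be:

import torch

# FoodClassifier is the class defined above; 256 classes is just a placeholder.
m = FoodClassifier(n_classes=256)
m.eval()
with torch.no_grad():
  feats = m.vision(torch.rand(1, 3, 224, 224))
  logits = m(torch.rand(1, 3, 224, 224))
print(feats.shape)   # torch.Size([1, 1408, 1, 1]) -> flattened to 1408
print(logits.shape)  # torch.Size([1, 256])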

Then I exported it to .ptl and ONNX using the following scripts:

#!/usr/bin/env python3
import torch
import numpy as np
from torch.utils.mobile_optimizer import optimize_for_mobile

from model import *
from utils import *

MODEL_PATH = "models/FoodClassifier.pt"
print("[+] Model path:", MODEL_PATH)


if __name__ == "__main__":
  dataset = FoodDataset(BASE_DIR)

  model = FoodClassifier(len(dataset.classes))
  model = load_model(MODEL_PATH, model)
  model.eval()
  print(model)

  example = torch.rand(1, 3, 224, 224)
  traced_script_module = torch.jit.trace(model, example)
  traced_script_module_optimized = optimize_for_mobile(traced_script_module)

  out_path = MODEL_PATH.split(".")[0] + ".ptl"
  traced_script_module_optimized._save_for_lite_interpreter(out_path)
  print("Model saved at:", out_path)

And the ONNX export script:

#!/usr/bin/env python3
import torch
import numpy as np

from model import *
from utils import *

MODEL_PATH = "models/FoodClassifier.pt"
print("[+] Model path:", MODEL_PATH)


if __name__ == "__main__":
  dataset = FoodDataset(BASE_DIR)

  model = FoodClassifier(len(dataset.classes))
  model = load_model(MODEL_PATH, model)
  model.eval()
  print(model)

  t_input = torch.randn(1, 3, 224, 224)
  out_path = MODEL_PATH.split(".")[0] + ".onnx"
  torch.onnx.export(model, t_input, out_path, input_names=["image"], output_names=["food_class"], verbose=True)
  print(f"[+] Exported {MODEL_PATH} at {out_path}")

These all seem to be correct, since my Python inference app can recognize a pizza from a jpg file:

#!/usr/bin/env python3
import os
import cv2
import numpy as np
import onnx
import onnxruntime
from PIL import Image
import torchvision.transforms as transforms

from model import *

BASE_DIR = "./data/"
W = H = 224
IDX = 30000 # pizza

MODEL_PATH = "models/FoodClassifier.pt"
MODEL_PATH_ONNX = "models/FoodClassifier.onnx"
print("[+] Model path:", MODEL_PATH)

def test_torch(img, label, classes):
  with torch.no_grad():
    model = FoodClassifier(len(classes))
    model = load_model(MODEL_PATH, model)
    model.eval()

    img = img.unsqueeze(0).float()

    print(img.shape)
    print(img)
    out = model(img)
    print(out)

    cat = np.argmax(out)
    print(cat)
    print("Prediction: ", classes[cat])


def test_onnx(img, label, classes):
  onnx_model = onnx.load(MODEL_PATH_ONNX)
  onnx.checker.check_model(onnx_model)

  img = img.unsqueeze(0).float()
  print(img)

  session = onnxruntime.InferenceSession(MODEL_PATH_ONNX)
  ort_inputs = {session.get_inputs()[0].name: img.detach().numpy()}
  ort_outputs = session.run(None, ort_inputs)[0]

  print(ort_outputs)
  print(ort_outputs.shape)
  cat = np.argmax(ort_outputs)
  print(cat)
  print("Prediction: ", classes[cat])
  return cat, classes[cat]


if __name__ == "__main__":
# ...
  transform = transforms.Compose([transforms.Resize((W, H)), transforms.ToTensor()])

  img = Image.open(image_files[IDX][1])
  image = transform(img)
  data = {"image": image, "label": np.array(image_files[IDX][0]), "path": image_files[IDX][1]}

  image, label, path = data["image"], data["label"], data["path"]
  test_torch(image, label, classes)
  pred, label = test_onnx(image, label, classes)

The input tensor looks like this:

torch.Size([1, 3, 224, 224])                                                                                                           
tensor([[[[0.7922, 0.8039, 0.8039,  ..., 0.4392, 0.5412, 0.5020],                                                                      
          [0.8078, 0.8118, 0.8157,  ..., 0.4667, 0.5098, 0.5059],                                                                      
          [0.8157, 0.8157, 0.8235,  ..., 0.4902, 0.5176, 0.4941],                                                                      
          ...,                                                                                                                         
          [0.9569, 0.9569, 0.9529,  ..., 0.8824, 0.8745, 0.8706],                                                                      
          [0.9569, 0.9569, 0.9529,  ..., 0.8745, 0.8706, 0.8667],                                                                      
          [0.9529, 0.9490, 0.9490,  ..., 0.8627, 0.8588, 0.8549]],                                                                     
                                                                                                                                       
         [[0.7843, 0.7961, 0.7961,  ..., 0.4431, 0.5373, 0.4863],                                                                      
          [0.8000, 0.8039, 0.8078,  ..., 0.4667, 0.5098, 0.4941],                                                                      
          [0.8078, 0.8078, 0.8157,  ..., 0.4706, 0.5059, 0.4745],                                                                      
          ...,                                                                                                                         
          [0.9569, 0.9569, 0.9529,  ..., 0.8824, 0.8784, 0.8745],                                                                      
          [0.9529, 0.9529, 0.9490,  ..., 0.8784, 0.8745, 0.8706],                                                                      
          [0.9490, 0.9451, 0.9451,  ..., 0.8784, 0.8745, 0.8706]],                                                                     
                                                                                                                                       
         [[0.7961, 0.8078, 0.8078,  ..., 0.3961, 0.5255, 0.4667],                                                                      
          [0.8118, 0.8157, 0.8196,  ..., 0.3961, 0.4510, 0.4549],                                                                      
          [0.8196, 0.8196, 0.8275,  ..., 0.4667, 0.4627, 0.4549],                                                                      
          ...,                                                                                                                         
          [0.9569, 0.9490, 0.9451,  ..., 0.9020, 0.8941, 0.8863],                                                                      
          [0.9451, 0.9451, 0.9373,  ..., 0.8902, 0.8824, 0.8784],                                                                      
          [0.9373, 0.9333, 0.9294,  ..., 0.8863, 0.8745, 0.8667]]]]) 

Both the .pt and ONNX models then give the same output, argmax = 17 (pizza),
although the raw values of the .pt model are something like:

tensor([[ -6.4250,  -5.6502,  -9.1750,  -6.7315,  -6.1992,  -4.3728,  -7.5658,

and the ONNX one's:

tensor([[[[0.7922, 0.8039, 0.8039,  ..., 0.4392, 0.5412, 0.5020],

which I don’t think plays a big role here.
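
To compare the two desktop outputs numerically rather than by eye, something along these lines can be used (model, img and MODEL_PATH_ONNX as in the inference script above):

import numpy as np
import onnxruntime
import torch

# Sketch: run the same preprocessed image through the .pt model and the ONNX
# session and compare the raw logits.
with torch.no_grad():
  torch_logits = model(img).numpy()

session = onnxruntime.InferenceSession(MODEL_PATH_ONNX)
onnx_logits = session.run(None, {"image": img.numpy()})[0]

print(np.argmax(torch_logits), np.argmax(onnx_logits))
print(np.abs(torch_logits - onnx_logits).max())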

Given that, I created an inference app on Android (for both ONNX and ptl):
CameraActivity.kt:

fun startCNNModel(imagePath: String) {
        try {
            Log.e("CNN", "Hello from CNN");

            var foodDetector = FoodDetector()
            var imageStream = assets.open("model/pizza.jpg")  // TODO: replace with imagePath

            // Load onnx model
            var modelPath = assetFilePath(this, "FoodClassifier.onnx")
            var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
            var ortSession: OrtSession = ortEnv.createSession(modelPath)
            var model_out = foodDetector.detectOnnx(imageStream, ortEnv, ortSession)

            // torchscript
//            val module = LiteModuleLoader.load(assetFilePath(this, "FoodClassifier.ptl"))
//            var model_out = foodDetector.detect(imageStream, module)

            val foodNamesModel = arrayOf("pizza", "pizza", "pizza")

            lifecycleScope.launch(Dispatchers.IO) {
                val foodNames = foodNamesModel.asList().joinToString(",").split(",").toTypedArray()

                // Translate the foodNames to the correct locale if necessary
                val translatedFoodNames = if (Locale.getDefault().language != "en") {
                    cameraViewModel.translateFoodNames(foodNames)
                } else {
                    foodNames
                }

                runOnUiThread {
                    binding.cameraProgressCircle.hide()
                    showPredictionCheckDialog(translatedFoodNames)
                }
            }

        } catch (e: Exception) {
            Log.e("Inference", "Error running inference", e)
        }
}

FoodDetector.kt (onnx):

fun preprocessImageOnnx(inputStream: InputStream, ortEnv: OrtEnvironment): OnnxTensor {
        val inputWidth = 224
        val inputHeight = 224
        // Decode the input stream into a bitmap
        val bitmap = BitmapFactory.decodeStream(inputStream)

        // Resize the image
        val resizedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true)

        // Convert the bitmap to a float array
        val floatArray = FloatArray(resizedBitmap.width * resizedBitmap.height * 3) // 3 channels for RGB
        for (y in 0 until resizedBitmap.height) {
            for (x in 0 until resizedBitmap.width) {
                val pixel = resizedBitmap.getPixel(x, y)
                val red = Color.red(pixel) / 255.0f
                val green = Color.green(pixel) / 255.0f
                val blue = Color.blue(pixel) / 255.0f

                // Assuming RGB order, you may need to adapt this to match your image format
                val index = (y * resizedBitmap.width + x) * 3
                floatArray[index] = red
                floatArray[index + 1] = green
                floatArray[index + 2] = blue
            }
        }

        println(floatArray.contentToString())

        val byteBuffer = ByteBuffer.allocateDirect(floatArray.size * 4).apply {
            asFloatBuffer().put(floatArray)
        }

        val inputTensor = OnnxTensor.createTensor(
            ortEnv,
            byteBuffer,
            longArrayOf(1, 3, inputHeight.toLong(), inputWidth.toLong()),
            OnnxJavaType.FLOAT
        )
        return inputTensor
    }

    fun detectOnnx(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): FloatArray {
        val inputTensor = preprocessImageOnnx(inputStream, ortEnv)

        inputTensor.use {
            // Step 4: Call ONNX inferenceSession run
            val output = ortSession.run(
                Collections.singletonMap("image", inputTensor),
                setOf("food_class")
            )

            Log.d("CNN out", output.toString())

            output.use {
                val classificationOutput = (output?.get(0)?.value) as Array<FloatArray>

                // Assuming classificationOutput is of shape (1, 256)
                val probabilities = classificationOutput[0]
                println(probabilities.contentToString())
                Log.e("Output Class", argmax(probabilities).toString())
                Log.e("Output Class", topK(probabilities, 5).contentToString())
                return probabilities
            }
        }
    }

FoodDetector.kt (ptl):

    fun preprocessImage(inputStream: InputStream): Tensor {
        //Decode the input stream into a bitmap
        val bitmap = BitmapFactory.decodeStream(inputStream)

        // Resize the image
        val resizedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true)

//        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
//            resizedBitmap,
//            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
//            TensorImageUtils.TORCHVISION_NORM_STD_RGB
//        )

        val floatArray = FloatArray(resizedBitmap.width * resizedBitmap.height * 3) // 3 channels for RGB

        for (y in 0 until resizedBitmap.height) {
            for (x in 0 until resizedBitmap.width) {
                val pixel = resizedBitmap.getPixel(x, y)
                val red = Color.red(pixel) / 255.0f
                val green = Color.green(pixel) / 255.0f
                val blue = Color.blue(pixel) / 255.0f

                // Assuming RGB order, you may need to adapt this to match your image format
                val index = (y * resizedBitmap.width + x) * 3
                floatArray[index] = red
                floatArray[index + 1] = green
                floatArray[index + 2] = blue
            }
        }
        
        println(floatArray.contentToString())
        println(floatArray.size)
        val inputTensor = Tensor.fromBlob(floatArray, longArrayOf(1, 3, IMAGE_WIDTH.toLong(), IMAGE_HEIGHT.toLong()))
        println(inputTensor.dataAsFloatArray.contentToString())
        println(inputTensor.shape().contentToString())
        inputStream.close()
        return inputTensor
    }

fun detect(inputStream: InputStream, module: Module): FloatArray {
        val inputTensor = preprocessImage(inputStream)
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        val outputArray = outputTensor.dataAsFloatArray
        println(outputArray.contentToString())
        Log.e("Output Class", argmax(outputArray).toString())
        Log.e("Output Class", topK(outputArray, 5).contentToString())
        // TODO: use class names and return top 5 names
        return outputArray
    }

Note that I have already tried the classic approach in the tutorials:

val inputStream: InputStream = assetManager.open(imageName)
val bitmap = BitmapFactory.decodeStream(inputStream)
...
val tensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB)

And after a lot of experimenting, I haven’t gotten anywhere. The input tensor values on Android are these:

[0.7882353, 0.78039217, 0.7921569, 0.8039216, 0.79607844, 0.80784315, 0.8039216, 0.79607844, 0.80784315, 0.8117647, 0.8039216, 0.8156863, 0.8156863, 0.80784315 ....

These are slightly different from the ones PIL loads from pizza.jpg.
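
(For reference, this is how the PIL-side values can be dumped in the same interleaved R,G,B-per-pixel order as the Android floatArray log above; the path is just where my copy of pizza.jpg happens to live:)

from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
t = transform(Image.open("data/pizza.jpg"))  # CHW tensor, shape (3, 224, 224)

# First five pixels of the top row, printed as R, G, B per pixel like the Android log.
for x in range(5):
  print(float(t[0, 0, x]), float(t[1, 0, x]), float(t[2, 0, x]))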

The ONNX output tensor is this:

[NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN,

and the PyTorch Lite one's is this:

[-4563.728, -14917.566, -14551.325, -16742.281, -10390.229, -11888.001, -15174.114, -13337.571, ...
argmax: 86
topK: [86, 0, 35, 16, 19]

which leads to a completely wrong prediction.

There isn’t much documentation on this, even though a lot of people have had issues deploying on Android in the past. Any ideas why this is happening? Since using the documented implementations was unsuccessful, what is the right way to deploy this model to Android?

Note that even with arrays/tensors full of 1.0f values, the two models still produce different outputs. I also tried normalizing the input during training (and in the Android inference code).
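
(To be concrete, the normalization I refer to is the standard torchvision ImageNet one, roughly this on the Python side, matching the TORCHVISION_NORM_MEAN_RGB / TORCHVISION_NORM_STD_RGB constants used on Android:)

import torchvision.transforms as transforms

# Standard ImageNet mean/std normalization, the same values used by
# TensorImageUtils.TORCHVISION_NORM_MEAN_RGB / TORCHVISION_NORM_STD_RGB.
transform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225]),
])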