Result mismatch between desktop and Android for custom MobileNetV2 image classification

Environment:
Desktop:
PyTorch: 1.9.0+cu111
Android:
pytorch_android_lite: 1.9.0
pytorch_android_torchvision: 1.9.0

I customized a pretrained MobileNetV2 network for binary image classification by replacing the last classifier layer:

model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=2)

At prediction time, I used the following transforms:

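from torchvision import transforms

transform = transforms.Compose([
    transforms.ToPILImage(),           # input: HWC uint8 ndarray or CHW tensor
    transforms.Resize((224, 224)),     # default interpolation: bilinear
    transforms.ToTensor(),             # scale to [0, 1] and convert HWC -> CHW
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])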
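For reference, here is a minimal NumPy sketch of what the ToTensor and Normalize steps do to an image that has already been resized. It assumes an 8-bit RGB array in height x width x channel layout, and the helper name `to_normalized_chw` is hypothetical. Dumping this array on the desktop lets you compare the model input element by element with the Android input, instead of only comparing the final scores.

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_hwc):
    """Approximate ToTensor() + Normalize() for a uint8 RGB array of shape (H, W, 3)."""
    x = img_hwc.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                      # Normalize: per channel, broadcast over the last axis
    return np.transpose(x, (2, 0, 1))         # HWC -> CHW, the layout ToTensor returns


if __name__ == "__main__":
    rgb = np.zeros((224, 224, 3), dtype=np.uint8)
    rgb[..., 0] = 255                         # a pure red test image
    t = to_normalized_chw(rgb)
    print(t.shape)                            # (3, 224, 224)
    print(t[:, 0, 0])                         # normalized R, G, B values of one pixel
```

The resize step is deliberately left out, so any remaining difference between the two dumped inputs points at resizing rather than normalization.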

I converted it to the mobile-deployable TorchScript format as follows:

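import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

with torch.no_grad():
    # rebuild the architecture with a 2-class head, then load the fine-tuned weights
    model = torchvision.models.mobilenet_v2(pretrained=True)
    model.classifier[1] = torch.nn.Linear(in_features=model.classifier[1].in_features, out_features=2)
    m = torch.load(<model.pth path>, map_location=torch.device('cpu'))
    model.load_state_dict(m)
    # append softmax so the exported model returns class probabilities
    model = torch.nn.Sequential(model, torch.nn.Softmax(1))
    model = model.to('cpu')
    model = model.eval()
    example = torch.rand(1, 3, 224, 224)
    # note: no qconfig/prepare() step was applied, so convert() leaves the float model effectively unchanged
    q_model = torch.quantization.convert(model)
    model_script = torch.jit.trace(q_model, example)
    optimized_scripted_module = optimize_for_mobile(model_script)
    optimized_scripted_module._save_for_lite_interpreter('model.ptl')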

I ran the mobile-deployable model on the desktop and compared the results. They match closely (about 99%), differing only slightly in the later decimal places, which is fine.

When I deployed it on Android, however, the results did not match; they are very different.

// resize to the model input size; filter=false means no bilinear filtering
Bitmap scaled_bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, false);
// prepare the input tensor: normalized with the torchvision mean/std, CHW (contiguous) layout
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(scaled_bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CONTIGUOUS);
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
final float[] scores = outputTensor.getDataAsFloatArray();
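To check the Android side on the desktop, the sketch below emulates what `TensorImageUtils.bitmapToFloat32Tensor` with `MemoryFormat.CONTIGUOUS` is assumed to do. This is based on my reading of the pytorch_android_torchvision source and is not an official description: each packed 0xAARRGGBB pixel (as returned by `Bitmap.getPixels`) is split into R, G and B, divided by 255, normalized with the torchvision mean/std, and written as three planes (all R, then all G, then all B). The function name `argb_pixels_to_planar` is hypothetical.

```python
import numpy as np

MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed TORCHVISION_NORM_MEAN_RGB
STD_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32)   # assumed TORCHVISION_NORM_STD_RGB


def argb_pixels_to_planar(pixels, width, height):
    """Assumed emulation of bitmapToFloat32Tensor(..., MemoryFormat.CONTIGUOUS).

    `pixels` are packed 0xAARRGGBB values in row-major order. Values coming from a
    Java int[] are signed, so mask them with `& 0xFFFFFFFF` before passing them in.
    """
    p = np.asarray(pixels, dtype=np.uint32).reshape(height, width)
    r = (p >> 16) & 0xFF
    g = (p >> 8) & 0xFF
    b = p & 0xFF
    rgb = np.stack([r, g, b]).astype(np.float32) / 255.0              # (3, H, W), values in [0, 1]
    return (rgb - MEAN_RGB[:, None, None]) / STD_RGB[:, None, None]   # planar R, G, B


if __name__ == "__main__":
    # red, green / blue, white
    out = argb_pixels_to_planar([0xFFFF0000, 0xFF00FF00, 0xFF0000FF, 0xFFFFFFFF], 2, 2)
    print(out.shape)  # (3, 2, 2)
```

If this matches the desktop tensor for the same 224x224 image, then channel order, layout and normalization agree, and the remaining suspect is the resize step.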

Desktop result: [0.9320468306541443, 0.06795316189527512]
Android result: [0.30155882, 0.69844115]

I also tried a static input of all ones. The results match:

Desktop result: [0.9261210560798645, 0.07387895882129669]
Android result: [0.9261202, 0.07387978]

I think something is wrong with the image preprocessing. I am not sure that the PyTorch transforms and the Android preprocessing functions are equivalent.

Any help regarding this?

Could you try a static input (e.g. torch.ones) and compare the outputs again? If they match, that would indicate that data loading and preprocessing are the likely root cause. Otherwise, the mismatch might originate in the model itself.

I already tried a static input of all ones (e.g. torch.ones). The results match:

Desktop result: [0.9261210560798645, 0.07387895882129669]
Android result: [0.9261202, 0.07387978]
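To make "matching" concrete, a minimal sketch that compares two score vectors within an absolute tolerance, using the numbers reported in this thread. The tolerance of 1e-5 is an assumption chosen to allow for float32 arithmetic on the device; the function name `outputs_match` is hypothetical.

```python
def outputs_match(a, b, atol=1e-5):
    """Return True if two score vectors agree element-wise within an absolute tolerance."""
    return len(a) == len(b) and all(abs(x - y) <= atol for x, y in zip(a, b))


# scores reported in this thread
desktop_ones = [0.9261210560798645, 0.07387895882129669]
android_ones = [0.9261202, 0.07387978]
desktop_image = [0.9320468306541443, 0.06795316189527512]
android_image = [0.30155882, 0.69844115]

print(outputs_match(desktop_ones, android_ones))    # True
print(outputs_match(desktop_image, android_image))  # False
```

The all-ones case differs by less than 1e-6 per element, while the image case differs by about 0.63, which is why the model itself is ruled out and the preprocessing is suspected.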

I followed the HelloWorldApp example. Is any extra preprocessing needed?

OK, this would indeed point to a data loading issue.
I don't know whether extra preprocessing is needed, but I would check whether the inputs are interleaved (channels-last instead of channels-first), use a different color order (BGR vs. RGB), etc.
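One way to run these checks is to dump the input tensor on both sides and test each hypothesis directly. The sketch below (the function name `diagnose_input` is hypothetical) takes a reference (3, H, W) desktop tensor and a flat buffer from the device, and reports whether the buffer matches as-is, with reversed channels, as an interleaved (HWC) buffer, or both. It assumes you compare un-normalized values (e.g. pixel values in [0, 1] before Normalize), because a swapped channel order would also be paired with the wrong per-channel mean and std.

```python
import numpy as np


def diagnose_input(reference_chw, candidate_flat, atol=1e-4):
    """Guess how a flat candidate buffer relates to a reference (3, H, W) tensor.

    Compare un-normalized values. Returns "same", "bgr", "interleaved",
    "interleaved_bgr", or None if no hypothesis fits.
    """
    ref = np.asarray(reference_chw, dtype=np.float32)
    c, h, w = ref.shape
    cand = np.asarray(candidate_flat, dtype=np.float32)
    hypotheses = {
        "same": cand.reshape(c, h, w),                                  # planar RGB, as expected
        "bgr": cand.reshape(c, h, w)[::-1],                             # planar, channels reversed
        "interleaved": cand.reshape(h, w, c).transpose(2, 0, 1),        # channels-last (HWC)
        "interleaved_bgr": cand.reshape(h, w, c).transpose(2, 0, 1)[::-1],
    }
    for name, arr in hypotheses.items():
        if np.allclose(arr, ref, atol=atol):
            return name
    return None


if __name__ == "__main__":
    ref = np.arange(12, dtype=np.float32).reshape(3, 2, 2)
    print(diagnose_input(ref, ref.transpose(1, 2, 0).ravel()))  # interleaved
```

If none of the hypotheses fits, the inputs differ in their pixel values themselves, which points at resizing or decoding rather than layout.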

I can see that resizing to 224x224 on Android degrades image quality. For comparison, I resized and saved images on Android, and I also resized the same downloaded images on the desktop with:

from torchvision import transforms
from PIL import Image

image = Image.open(<path of image>)
transform = transforms.Resize((224, 224))  # bilinear interpolation by default
img = transform(image)
display(img)  # display() is available in Jupyter/IPython notebooks

I then fed the image resized with the desktop transform to the Android app, and the results matched exactly.

Android code for resizing:
Bitmap inp_bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, false);

What would be a good image resize utility that is equivalent to the PyTorch Resize transform?

Any help is appreciated.

The default interpolation mode for torchvision's Resize is bilinear.
For the Android function you mentioned, false selects the fast method (nearest neighbor), and true enables bilinear filtering, which is slower.
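Why the resize method matters so much for large downscales can be illustrated with a small sketch. For PIL images, torchvision's Resize calls PIL's resize, and PIL widens its bilinear filter when shrinking, so each output pixel blends many input pixels. Plain bilinear sampling, or nearest neighbor, only looks at one or a few source pixels. How Android's filtered scaling behaves exactly is an assumption not verified here. The sketch contrasts nearest-neighbor sampling with block averaging (a simple stand-in for antialiased downscaling, not PIL's exact filter); both function names are hypothetical.

```python
import numpy as np


def resize_nearest(img, out_h, out_w):
    """Nearest-neighbor downscale: each output pixel copies one source pixel."""
    in_h, in_w = img.shape[:2]
    rows = ((np.arange(out_h) + 0.5) * in_h / out_h).astype(int)  # source row per output row
    cols = ((np.arange(out_w) + 0.5) * in_w / out_w).astype(int)  # source column per output column
    return img[rows[:, None], cols[None, :]]


def resize_area(img, out_h, out_w):
    """Block-average downscale for integer factors: each output pixel averages a block."""
    in_h, in_w = img.shape[:2]
    fh, fw = in_h // out_h, in_w // out_w
    assert fh * out_h == in_h and fw * out_w == in_w, "integer scale factors only"
    return img.reshape(out_h, fh, out_w, fw, *img.shape[2:]).mean(axis=(1, 3))


# a 1-pixel checkerboard is the worst case: it is pure high-frequency detail
board = (np.indices((8, 8)).sum(axis=0) % 2).astype(np.float32)
print(resize_nearest(board, 2, 2))  # every sample lands on a 0 pixel
print(resize_area(board, 2, 2))     # every block averages to 0.5
```

Real photos are less extreme, but fine texture downscaled from a camera-sized bitmap to 224x224 can still differ enough between the two approaches to change a classifier's output.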

Thanks for the reply.

I tried that option too:
Bitmap inp_bitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);

The results are still not that promising.

I just tried the Android Picasso and Glide libraries, applying their resize and centerInside functions. The results are good but still do not match exactly.