I have a problem that a lot of people have also experienced.
I trained a model/classifier loading and image from .jpg file using this code:
def __init__(self, ...):
...
self.transform = transforms.Compose([transforms.Resize((W, H)), transforms.ToTensor()])
def __getitem__(self, idx):
# image = read_image(self.image_files[idx][1])
image = self.transform(Image.open(self.image_files[idx][1]))
return {"image": image, "label": np.array(self.image_files[idx][0]), "path": self.image_files[idx][1]}
...
# prepare tensor for model forwarding
X = sample_batched["image"].float().to(self.device)
...
out = model(out)
def torch_to_mobile():
dataset = FoodDataset(BASE_DIR)
model = FoodClassifier(len(dataset.classes))
model = load_model(MODEL_PATH, model)
model.eval()
print(model)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
out_path = MODEL_PATH.split(".")[0] + ".ptl"
traced_script_module_optimized._save_for_lite_interpreter(out_path)
print("Model saved at:", out_path)
and deployed a small test on android studio.
build.gradle:
implementation 'org.pytorch:pytorch_android_lite:1.13.1'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.13.1'
CameraActivity.kt:
// torchscript
val module = LiteModuleLoader.load(assetFilePath(this, "FoodClassifier.ptl"))
// val module = Module.load(assetFilePath(this, "FoodClassifier.pt"))
var foodDetector = FoodDetector()
var imageStream = assets.open("model/pizza.jpg") // TODO: replace with imagePath
var model_out = foodDetector.detect(imageStream, module)
FoodDetector.kt:
fun preprocessImage(inputStream: InputStream): Tensor {
//Decode the input stream into a bitmap
val bitmap = BitmapFactory.decodeStream(inputStream)
// Resize the image
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, IMAGE_WIDTH, IMAGE_HEIGHT, true)
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
)
println(inputTensor.dataAsFloatArray.contentToString())
return inputTensor
}
fun argmax(array: FloatArray): Int {
var maxIndex = 0
var maxValue = Float.NEGATIVE_INFINITY
for (i in array.indices) {
if (array[i] > maxValue) {
maxValue = array[i]
maxIndex = i
}
}
return maxIndex
}
fun topK(a: FloatArray, topk: Int): IntArray {
val values = FloatArray(topk)
Arrays.fill(values, -java.lang.Float.MAX_VALUE)
val ixs = IntArray(topk)
Arrays.fill(ixs, -1)
for (i in a.indices) {
for (j in 0 until topk) {
if (a[i] > values[j]) {
for (k in topk - 1 downTo j + 1) {
values[k] = values[k - 1]
ixs[k] = ixs[k - 1]
}
values[j] = a[i]
ixs[j] = i
break
}
}
}
return ixs
}
// FIXME: WRONG CLASS OUTPUT
fun detect(inputStream: InputStream, module: Module): FloatArray{
val inputTensor = preprocessImage(inputStream)
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
val outputArray = outputTensor.dataAsFloatArray
println(outputArray.contentToString())
Log.e("Output Class", argmax(outputArray).toString())
Log.e("Output Class", topK(outputArray, 5).contentToString())
// TODO: use class names and return top 3 names
return outputArray
}
The input and output values of the pytorch inference code are:
input image:
torch.Size([3, 224, 224]) 17 ./data/18/1702.jpg
tensor([[[0.7922, 0.8039, 0.8039, ..., 0.4392, 0.5412, 0.5020],
[0.8078, 0.8118, 0.8157, ..., 0.4667, 0.5098, 0.5059],
[0.8157, 0.8157, 0.8235, ..., 0.4902, 0.5176, 0.4941],
...,
[0.9569, 0.9569, 0.9529, ..., 0.8824, 0.8745, 0.8706],
[0.9569, 0.9569, 0.9529, ..., 0.8745, 0.8706, 0.8667],
[0.9529, 0.9490, 0.9490, ..., 0.8627, 0.8588, 0.8549]],
[[0.7843, 0.7961, 0.7961, ..., 0.4431, 0.5373, 0.4863],
[0.8000, 0.8039, 0.8078, ..., 0.4667, 0.5098, 0.4941],
[0.8078, 0.8078, 0.8157, ..., 0.4706, 0.5059, 0.4745],
...,
[0.9569, 0.9569, 0.9529, ..., 0.8824, 0.8784, 0.8745],
[0.9529, 0.9529, 0.9490, ..., 0.8784, 0.8745, 0.8706],
[0.9490, 0.9451, 0.9451, ..., 0.8784, 0.8745, 0.8706]],
[[0.7961, 0.8078, 0.8078, ..., 0.3961, 0.5255, 0.4667],
[0.8118, 0.8157, 0.8196, ..., 0.3961, 0.4510, 0.4549],
[0.8196, 0.8196, 0.8275, ..., 0.4667, 0.4627, 0.4549],
...,
[0.9569, 0.9490, 0.9451, ..., 0.9020, 0.8941, 0.8863],
[0.9451, 0.9451, 0.9373, ..., 0.8902, 0.8824, 0.8784],
[0.9373, 0.9333, 0.9294, ..., 0.8863, 0.8745, 0.8667]]])
output tensor:
tensor([[ -6.4250, -5.6502, -9.1750, -6.7315, -6.1992, -4.3728, -7.5658,
-9.0300, -5.0554, -8.4527, -8.4998, -6.2553, -4.4357, -10.0883,
-5.4501, -4.1953, -4.1803, 1.5447, -5.4573, -8.1641, -8.8581,
-3.8452, -5.7401, -6.1678, -8.2289, -5.5331, -5.4196, -4.3286,
-8.2148, -7.4791, -7.9286, -7.2004, -6.6325, -6.5702, -7.2533,
-5.0552, -7.9561, -7.6686, -9.8013, -4.1677, -8.1728, -7.2377,
-8.6816, -9.3092, -6.0697, -5.9184, -6.5978, -6.0070, -7.9605,
-8.2251, -10.0780, -9.1477, -7.7357, -6.9407, -9.6364, -7.6766,
-8.1835, -8.6580, -9.2632, -7.4976, -8.6892, -7.7322, -8.5101,
-2.8730, -7.6723, -10.3767, -6.1771, -6.8715, -6.4986, -6.5030,
-9.9414, -8.1596, -9.3136, -9.7785, -9.1811, -5.0005, -6.1989,
-7.3114, -5.0766, -3.7111, -8.5137, -4.7980, -9.5621, -5.0810,
-7.7581, -9.5593, -5.4868, -8.5949, -6.6317, -8.1714, -6.6684,
-8.0697, -9.6710, -6.6122, -3.2688, -6.6061, -6.7667, -8.8379,
-6.9095, -9.3807, -7.8291, -9.1676, -7.9217, -7.5481, -6.5465,
-7.8089, -7.8403, -9.2809, -4.9770, -3.7320, -3.9532, -9.1324,
-8.4911, -5.0800, -9.1694, -7.4032, -8.2069, -3.2996, -5.1245,
-3.7846, -9.7733, -8.3546, -5.6607, -8.4584, -9.5602, -6.6742,
-5.9762, -5.0890, -10.2585, -9.5369, -8.8205, -10.1797, -7.0062,
-7.3069, -8.5791, -9.0843, -7.4706, -8.1020, -7.8693, -6.7647,
-6.9144, -9.1066, -7.0260, -8.1277, -5.7195, -5.1588, -2.2246,
-7.1596, -4.9242, -6.5084, -6.9575, -5.9987, -4.5282, -7.2608,
-3.6212, -10.4534, -8.7601, -6.4402, -7.1080, -6.2117, -7.6964,
-3.3375, -6.1063, -7.0410, -8.1057, -10.1912, -8.3852, -7.2990,
-9.3008, -11.2928, -5.9536, -8.6684, -6.2913, -7.2556, -10.7067,
-12.5549, -6.6373, -6.4069, -9.3014, -10.7260, -4.8521, -8.0948,
-4.0307, -7.4884, -5.5963, -7.6681, -8.5791, -10.4596, -8.1598,
-5.5236, -4.5662, -7.4979, -9.3129, -8.0146, -9.4046, -4.0687,
-8.6144, -7.0490, -6.4512, -6.4222, -6.7708, -4.7155, -4.3436,
-6.9712, -5.0127, -6.2544, -4.3396, -3.6489, -7.5470, -8.7146,
-6.7608, -8.7649, -2.1415, -8.3589, -10.9746, -6.6287, -10.1222,
-10.3195, -7.5870, -7.4546, -7.7881, -8.1478, -11.1285, -8.2365,
-11.6734, -4.1729, -3.8706, -7.9013, -7.5459, -6.8718, -8.3584,
-4.5214, -7.5231, -3.4609, -6.9615, -5.7563, -5.6318, -6.6918,
-9.4669, -9.6109, -12.2082, -3.7341, -7.6879, -8.2632, -7.5099,
-10.9171, -9.8948, -6.2009, -7.2751, -4.1637, -8.0830, -7.1027,
-7.4779, -7.9212, -3.2335, -5.6123]])
argmax = 17
Prediction: pizza
while in kotlin I get this:
input:
1.3241715, 1.3926706, 1.3926706, 1.42692, 1.4440448, 1.4611696, 1.4611696, 1.4611696, 1.4611696, 1.4611696, 1.4611696, 1.4782944, 1.4782944, 1.4782944, 1.4782944, 1.4782944, 1.4440448, 1.4611696, 1.4440448, 1.4611696, 1.4782944, 1.4782944, 1.4782944, 1.4954191, 1.5296686, 1.5125438, 1.5296686, 1.4782944, 1.4954191, 1.4782944, 1.4782944, 1.4782944, 1.5125438, 1.4954191, 1.4782944, 1.4954191, 1.4954191, 1.4954191, 1.5125438, 1.4954191, 1.4782944, 1.4782944, 1.4954191, 1.5296686, 1.5125438, 1.5296686, 1.5296686, 1.4782944, 1.4611696, 1.4611696, 1.4782944, 1.4782944, 1.4954191, 1.5125438, 1.4954191, 1.4954191, 1.5125438, 1.5125438, 1.4440448, 1.4954191, 1.5467933, 1.5296686, 1.4611696, 1.4954191, 1.5296686, 1.3412963, 0.5535577, 0.96455175, 1.1700488, 1.3755459, 1.2556726, 1.1529241, 1.0673003, 0.878928, 0.34806067, -0.49105233, -1.3301654, -1.5870366, -1.5699118, -1.5185376, -1.4671633, -1.4329139, -1.4842881, -1.4500387, -1.4671633, -1.1760426, -1.3644148, -1.3815396, -1.3472902, -1.4157891, -1.4671633, -0.95342064, -1.6212862, -1.6555357, -1.6726604, -1.6212862, -1.5699118, -1.6041614, -1.6555357, -1.5699118, -1.4500387, -1.2787911, -1.0219197, -0.9705454, -0.79929787, -0.91917115, -0.9876702, -1.1075436, -1.141793, -1.4329139, -1.4671633, -1.6041614, -1.5870366, -1.6041614, -1.6384109, -1.5699118, -1.5870366, -1.5356624, -0.8335474, -0.7136741, 0.15968838, 0.4850587, 0.5364329, 0.5193082, 0.1939379, -0.95342064, -1.6041614, -1.6555357, -1.6555357, -1.5870366, -1.4500387, -0.7307989, -1.5014129, -1.5356624, -1.4500387, -1.6041614, -1.4500387, -0.9705454, -0.57667613, -1.2274169, -0.8506721, 0.5364329, 1.2727973, 1.5981677, 1.5639181, 1.5981677, 1.5981677, 1.6152923, 1.5981677, 1.5467933, 1.5810429, 1.6152923, 1.6324171, 1.6324171, 1.6324171, 1.6324171, 1.6152923, 1.5981677, 1.5981677, 1.6324171, 1.6495419, 1.6495419, 1.6324171, 1.5810429, 1.5467933, 1.5810429, 1.5639181, 1.5639181, 1.5467933, 1.5467933, 1.5639181, 1.5296686, 1.4954191, 1.4782944, 1.4954191, 1.5125438, 1.42692, 0.60493195, 1.2727973, 0.15968838, 0.34806067, -0.2855553, -0.11430778, 0.2966864, 0.022690238, 0.57068247, 0.79330426, 0.46793392, 0.108314134, 0.2966864, 0.6220567, 0.43368444, -0.2855553, 0.45080918, 0.79330426, 0.46793392, -0.011559267, 0.1939379, 0.6220567, 0.5878072, 0.2281874, 0.022690238, 0.45080918, 0.57068247, 0.57068247, 0.12543888, -0.02868402, 0.056939743, -0.33692956, -0.45680285, -0.35405433, -0.69654936, -0.7479236, -0.67942464, -0.6451751, -0.30268008, 0.34806067, 0.07406463, -0.16568205, -0.2170563, -0.43967807, -0.37117907, 0.3651854, -0.06293353, 1.3926706, 1.42692, 1.42692, 1.4611696, 1.4782944, 1.4782944, 1.4782944, 1.4954191, 1.4954191, 1.4954191, 1.5125438, 1.5125438, 1.5125438, 1.4954191, 1.4954191, 1.5296686, 1.4954191, 1.4954191, 1.4954191, 1.4782944, 1.4782944, 1.5125438, 1.4954191, 1.5125438, 1.5296686, 1.5296686, 1.5467933, 1.5296686, 1.5296686, 1.5296686, 1.5296686, 1.5296686, 1.5296686, 1.4782944, 1.4782944, 1.5125438, 1.5125438, 1.5125438, 1.5125438, 1.5125438, 1.5125438, 1.5125438, 1.5467933, 1.5467933, 1.4954191, 1.4954191, 1.5125438, 1.4954191, 1.4954191, 1.5296686, 1.4954191, 1.4954191, 1.5125438, 1.5296686, 1.5296686, 1.5467933, 1.5467933, 1.5639181, 1.5296686, 1.5296686, 1.5125438, 1.5296686, 1.5467933, 1.5296686, 1.4611696, 0.17681314, 0.79330426, 1.1186745, 1.2556726, 1.4954191, 0.79330426, 0.45080918, -0.11430778, -0.9020464, -1.2787911, -1.3815396, -1.5185376, -1.6041614, -1.6212862, -1.5699118, -1.5014129, -1.4842881, -1.5185376, -1.5870366, -1.4671633, -1.2102921, -1.4329139, -0.31980482, 1.0844251, 0.07406463, -1.0390445, -1.3644148, -1.5014129, -1.5870366, -1.6555357, -1.6041614, -1.6897851, -1.7411594, -1.6384109, -1.6212862, -1.7069099, -1.8096584, -1.8267832, -1.5870366, -1.7754089, -1.6555357, -1.5870366, -1.3130406, -0.14855729, 0.07406463, 0.27956167, 0.45080918, 0.108314134, -0.26843056, -0.5424266, -1.2102921, -1.4329139, -0.78217316, 1.1186745, 1.3584211, 1.2899221, 1.4097953, 1.3926706, 1.3070468, 1.4097953, -0.43967807, -1.6212862, -1.6212862, -1.6726
output:
[-30571.434, -68474.21, -58631.973, -58655.77, -42851.656, -49321.5, -74000.78, -45550.594, -42487.66, -47610.95, -47112.367, -50195.094, -80131.12, -71236.83, -62103.164, -57046.42, -38206.94, -53553.215, -59162.234, -33317.25, -47081.96, -56312.715, -52878.688, -58229.375, -61969.867, -48885.074, -48899.348, -48952.797, -65049.375, -58910.6, -48149.86, -48899.508, -61258.137, -58512.56, -45999.316, -27279.934, -67456.63, -43162.492, -54333.75, -52835.52, -52082.17, -48469.625, -74907.0, -70001.73, -48904.617, -59542.113, -70019.83, -62489.633, -63455.734, -50524.953, -63811.457, -63315.016, -51634.258, -58028.297, -47732.03, -49585.586, -61240.906, -72493.79, -54116.28, -36536.13, -56610.09, -84276.18, -48778.145, -62885.145, -68803.664, -68755.25, -49887.8, -38417.86, -54181.51, -43047.52, -63574.594, -56797.684, -64269.664, -58621.4, -65094.836, -43314.88, -50006.777, -73834.35, -64065.855, -41414.45, -57358.688, -56874.336, -62199.562, -52377.223, -68186.0, -47789.2, -30251.148, -45542.31, -64639.17, -50332.48, -47525.805, -32910.914, -59852.234, -54682.273, -66385.18, -57978.34, -70564.95, -40002.62, -47921.38, -51943.06, -60722.63, -58308.273, -64530.25, -67543.266, -67889.17, -56492.355, -62865.33, -46185.566, -45770.844, -70346.53, -57840.74, -75518.64, -57075.11, -63876.586, -65690.56, -55782.9, -58676.832, -58708.035, -68588.14, -64086.01, -60896.445, -63098.8, -60064.04, -52715.16, -66535.88, -61510.812, -47626.742, -73888.375, -58530.805, -70084.61, -61040.992, -76335.8, -72583.86, -69310.805, -75871.484, -52367.4, -70609.01, -58147.6, -64791.324, -56125.38, -58661.89, -53612.29, -44112.125, -71925.445, -64066.24, -41409.543, -58928.977, -54544.273, -58300.254, -42870.027, -51363.875, -69842.9, -54098.38, -50656.51, -59357.914, -85024.78, -79363.766, -48345.496, -72155.81, -53595.113, -63647.656, -58873.555, -58428.95, -69631.47, -72781.26, -88182.734, -62204.895, -54846.152, -51062.11, -58996.75, -66660.805, -65060.32, -71382.54, -42450.395, -66750.2, -88973.8, -72162.55, -39830.406, -64316.156, -51310.184, -68144.24, -59078.906, -79519.65, -47931.77, -52535.75, -68417.2, -64661.246, -65523.582, -68096.01, -48051.195, -55878.38, -60802.79, -56804.562, -54431.117, -67300.89, -56095.617, -63731.19, -65780.5, -53809.26, -53394.867, -70708.01, -53738.406, -49251.164, -49791.53, -63830.4, -68972.695, -69670.914, -67653.63, -60360.406, -64695.46, -51332.703, -66104.2, -58346.285, -83154.0, -84073.09, -46853.6, -63334.457, -63052.02, -60205.98, -74087.55, -70496.31, -69451.69, -61029.27, -59428.562, -70393.96, -57312.105, -43012.18, -59892.53, -56409.316, -61937.855, -74559.3, -70559.36, -44311.426, -58755.56, -47393.746, -52536.26, -43792.344, -51401.67, -66035.95, -70937.68, -63233.793, -55476.37, -52954.754, -83635.3, -70090.95, -72397.42, -63570.58, -66000.67, -73642.28, -49205.938, -57886.99, -69780.07, -65842.164, -79370.42, -55849.273, -69387.66]
argmax: 35