Hi, I’m attempting to run a two-category image classification model in an app developed in Android Studio. I’ve been following the tutorial at this link:
The code runs successfully with the ResNet-18 model used in the demo code, but when I use my own model, the app crashes as soon as I try to detect the object in the loaded image. I have also changed ModelClasses.java so that it outputs my own labels.
The module-level Gradle build file is as follows:
// Module-level (app) Gradle build script for the crack-detection demo.
plugins {
id 'com.android.application'
}
android {
compileSdkVersion 30
buildToolsVersion "30.0.3"
defaultConfig {
applicationId "com.example.crack_detection_java_demo"
minSdkVersion 24
targetSdkVersion 30
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
// The app sources are compiled as Java 8.
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
// PyTorch Mobile runtime, plus the torchvision helpers that provide TensorImageUtils.
// NOTE(review): a TorchScript model exported with a newer PyTorch release than this
// runtime may fail in Module.load() — confirm the PyTorch version used to trace the
// model is compatible with 1.8.0.
implementation 'org.pytorch:pytorch_android:1.8.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0'
// NOTE(review): presumably only needed for models that use torchvision custom ops
// (e.g. detection models); a plain classifier may not need it — verify.
implementation 'org.pytorch:torchvision_ops:0.9.0'
// NOTE(review): this pins an alpha release of Material Components; consider a stable version.
implementation 'com.google.android.material:material:1.2.0-alpha03'
}
The MainActivity.java code is as follows:
package com.example.crack_detection_java_demo;
import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.graphics.drawable.Drawable;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import androidx.appcompat.app.AppCompatActivity;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
public class MainActivity extends AppCompatActivity {
private static final int RESULT_LOAD_IMAGE = 1;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Button buttonLoadImage = findViewById(R.id.button);
Button detectButton = findViewById(R.id.detect);
requestPermissions(new String[]{android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);
buttonLoadImage.setOnClickListener(arg0 -> {
TextView textView = findViewById(R.id.result_text);
textView.setText("");
Intent i = new Intent(
Intent.ACTION_PICK,
MediaStore.Images.Media.EXTERNAL_CONTENT_URI);
startActivityForResult(i, RESULT_LOAD_IMAGE);
});
detectButton.setOnClickListener(arg0 -> {
Bitmap bitmap = null;
Module module = null;
//Getting the image from the image view
ImageView imageView = findViewById(R.id.image);
try {
//Read the image as Bitmap
bitmap = ((BitmapDrawable)imageView.getDrawable()).getBitmap();
//Here we reshape the image into 400*400
bitmap = Bitmap.createScaledBitmap(bitmap, 400, 400, true);
//Loading the model file.
module = Module.load(fetchModelFile(MainActivity.this, "20210613_CNNmodel_traced.pt"));
} catch (IOException e) {
finish();
}
//Input Tensor
final Tensor input = TensorImageUtils.bitmapToFloat32Tensor(
bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB
);
//Calling the forward of the model to run our input
assert module != null;
final Tensor output = module.forward(IValue.from(input)).toTensor();
final float[] score_arr = output.getDataAsFloatArray();
// Fetch the index of the value with maximum score
float max_score = -Integer.MAX_VALUE;
int ms_ix = -1;
for (int i = 0; i < score_arr.length; i++) {
if (score_arr[i] > max_score) {
max_score = score_arr[i];
ms_ix = i;
}
}
//Fetching the name from the list based on the index
String detected_class = ModelClasses.MODEL_CLASSES[ms_ix];
//Writing the detected class in to the text view of the layout
TextView textView = findViewById(R.id.result_text);
textView.setText(detected_class);
// textView.setText((CharSequence) output);
});
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
//This functions return the selected image from gallery
super.onActivityResult(requestCode, resultCode, data);
if (requestCode == RESULT_LOAD_IMAGE && resultCode == RESULT_OK && null != data) {
Uri selectedImage = data.getData();
String[] filePathColumn = { MediaStore.Images.Media.DATA };
Cursor cursor = getContentResolver().query(selectedImage,
filePathColumn, null, null, null);
cursor.moveToFirst();
int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
String picturePath = cursor.getString(columnIndex);
cursor.close();
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(BitmapFactory.decodeFile(picturePath));
//Setting the URI so we can read the Bitmap from the image
imageView.setImageURI(null);
imageView.setImageURI(selectedImage);
}
}
public static String fetchModelFile(Context context, String modelName) throws IOException {
File file = new File(context.getFilesDir(), modelName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(modelName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
And the model classes file is:
package com.example.crack_detection_java_demo;
/**
 * Human-readable labels for the classifier's outputs.
 *
 * <p>The label at index {@code i} is reported when output score {@code i} is the largest. The
 * order here must therefore match the class order used when the model was trained. If the model
 * was trained with torchvision's {@code ImageFolder}, that is the alphabetical order of the class
 * folder names.
 */
public final class ModelClasses {
    /** Class labels, indexed by model output position. Do not modify at runtime. */
    public static final String[] MODEL_CLASSES = new String[]{
            "Negative",
            "Positive"
    };

    private ModelClasses() {
        // Constants holder; not instantiable.
    }
}
Android Studio doesn’t report any errors, and I’m at a loss as to where to begin troubleshooting. Could anyone advise on how I should approach this?