Image classification app crashing

Hi, I’m attempting to run a two-category image classification model in an app developed in Android Studio. I’ve been following a tutorial at the link here:

The code runs successfully with the resnet18 model used in the demo code, but when I use my own model, the app crashes as soon as I try to detect the object in the loaded image. I have also changed a file called ModelClasses.java to output my labels.

The Gradle module code is as follows.

// Module-level Gradle build script for the Android application module.
plugins {
    id 'com.android.application'
}

android {
    // SDK level and build-tools version the module is compiled with.
    compileSdkVersion 30
    buildToolsVersion "30.0.3"

    defaultConfig {
        applicationId "com.example.crack_detection_java_demo"
        // Lowest API level the app installs on, and the API level it targets.
        minSdkVersion 24
        targetSdkVersion 30
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    }

    buildTypes {
        release {
            // Code shrinking is switched off for release builds; the ProGuard
            // files below only take effect if minifyEnabled is turned on.
            minifyEnabled false
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    compileOptions {
        // Java 8 source/target level; MainActivity uses lambdas for its click listeners.
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
}

dependencies {
    // Any jars dropped into the module's libs/ directory.
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'androidx.appcompat:appcompat:1.3.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
    // PyTorch Android runtime plus the torchvision helpers (TensorImageUtils),
    // both pinned to 1.8.0.
    // NOTE(review): the model file should be exported with a PyTorch version
    // compatible with this runtime — confirm against the training environment.
    implementation 'org.pytorch:pytorch_android:1.8.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.8.0'
    // NOTE(review): presumably only required for models that call torchvision
    // custom ops — confirm whether this classifier needs it.
    implementation 'org.pytorch:torchvision_ops:0.9.0'
    // NOTE(review): this is pinned to an alpha release; consider a stable version.
    implementation 'com.google.android.material:material:1.2.0-alpha03'
}

The MainActivity.java code is as follows

package com.example.crack_detection_java_demo;

import android.content.Context;
import android.content.Intent;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class MainActivity extends AppCompatActivity {
    private static final int RESULT_LOAD_IMAGE = 1;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        Button buttonLoadImage = findViewById(R.id.button);
        Button detectButton = findViewById(R.id.detect);


        requestPermissions(new String[]{android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);
        buttonLoadImage.setOnClickListener(arg0 -> {
            TextView textView = findViewById(R.id.result_text);
            textView.setText("");
            Intent i = new Intent(
                    Intent.ACTION_PICK,
                    MediaStore.Images.Media.EXTERNAL_CONTENT_URI);

            startActivityForResult(i, RESULT_LOAD_IMAGE);


        });

        detectButton.setOnClickListener(arg0 -> {

            Bitmap bitmap = null;
            Module module = null;

            //Getting the image from the image view
            ImageView imageView = findViewById(R.id.image);

            try {
                //Read the image as Bitmap
                bitmap = ((BitmapDrawable)imageView.getDrawable()).getBitmap();

                //Here we reshape the image into 400*400
                bitmap = Bitmap.createScaledBitmap(bitmap, 400, 400, true);

                //Loading the model file.
                module = Module.load(fetchModelFile(MainActivity.this, "20210613_CNNmodel_traced.pt"));
            } catch (IOException e) {
                finish();
            }

            //Input Tensor
            final Tensor input = TensorImageUtils.bitmapToFloat32Tensor(
                    bitmap,
                    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                    TensorImageUtils.TORCHVISION_NORM_STD_RGB
            );

            //Calling the forward of the model to run our input
            assert module != null;
            final Tensor output = module.forward(IValue.from(input)).toTensor();
            final float[] score_arr = output.getDataAsFloatArray();



            // Fetch the index of the value with maximum score
            float max_score = -Integer.MAX_VALUE;
            int ms_ix = -1;
            for (int i = 0; i < score_arr.length; i++) {
                if (score_arr[i] > max_score) {
                    max_score = score_arr[i];
                    ms_ix = i;
               }
            }

            //Fetching the name from the list based on the index
            String detected_class = ModelClasses.MODEL_CLASSES[ms_ix];

            //Writing the detected class in to the text view of the layout
            TextView textView = findViewById(R.id.result_text);
            textView.setText(detected_class);
//            textView.setText((CharSequence) output);


        });

    }
    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        //This functions return the selected image from gallery
        super.onActivityResult(requestCode, resultCode, data);

        if (requestCode == RESULT_LOAD_IMAGE && resultCode == RESULT_OK && null != data) {
            Uri selectedImage = data.getData();
            String[] filePathColumn = { MediaStore.Images.Media.DATA };

            Cursor cursor = getContentResolver().query(selectedImage,
                    filePathColumn, null, null, null);
            cursor.moveToFirst();

            int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
            String picturePath = cursor.getString(columnIndex);
            cursor.close();

            ImageView imageView = findViewById(R.id.image);
            imageView.setImageBitmap(BitmapFactory.decodeFile(picturePath));

            //Setting the URI so we can read the Bitmap from the image
            imageView.setImageURI(null);
            imageView.setImageURI(selectedImage);


        }


    }

    public static String fetchModelFile(Context context, String modelName) throws IOException {
        File file = new File(context.getFilesDir(), modelName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(modelName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

}

And the model classes file is:

package com.example.crack_detection_java_demo;

/**
 * Human-readable labels for the classifier's output.
 */
public class ModelClasses {
    /**
     * Class labels, looked up by the index of the highest-scoring model output.
     * The reference is {@code final} so the table cannot be swapped out at runtime.
     */
    public static final String[] MODEL_CLASSES = new String[]{
            "Negative",
            "Positive"
    };
}

Android Studio doesn’t give any errors, and I’m at a loss as to how to begin troubleshooting the code. Could anyone advise on how I should approach troubleshooting this?

Just to clarify, I don’t need help figuring out why my app is crashing, just what methods I should undertake to troubleshoot the issues I’m experiencing or where I can read up on troubleshooting methods that would be applicable.

@IvanKobzarev can you provide any help here?

Eventually solved the problem, just leaving this here in case anyone else is ever in the same boat. When I traced my model in pytorch I was sending my model to the GPU. I changed this to the CPU and everything started working. I don’t know why, I don’t know if there was something else going on, all I know is that things started working after I made this change.