1-channel input on PyTorch Mobile - Android

Hi, I’m trying to build an MNIST classifier with PyTorch Mobile.
I’ve exported my model in Python with TorchScript (torch.jit.trace):

```python
example = torch.rand(1, 1, 28, 28)
trace = torch.jit.trace(net, example)
trace.save("mymodel.pt")
```

And this is how I use it in my Kotlin project:

        fun assetFilePath(context: Context, assetName: String): String {
            val file = File(context.filesDir, assetName)
            if (file.exists() && file.length() > 0) {
                return file.absolutePath
            }
            context.assets.open(assetName).use { inputStream ->
                FileOutputStream(file).use { outputStream ->
                    val buffer = ByteArray(4 * 1024)
                    var read: Int
                    while (inputStream.read(buffer).also { read = it } != -1) {
                        outputStream.write(buffer, 0, read)
                    }
                    outputStream.flush()
                }
                return file.absolutePath
            }
        }

        val module:Module = Module.load(assetFilePath(this, "mymodel.pt"))

        val resizedBitmap:Bitmap = Bitmap.createScaledBitmap(
            MyCanvasView.extraBitmap,
            28,
            28,
            true
        )

        val inputTensor: Tensor = TensorImageUtils.bitmapToFloat32Tensor(
            resizedBitmap,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB
        )

        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

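The problem is that inputTensor’s shape is [1, 3, 28, 28]: TensorImageUtils.bitmapToFloat32Tensor always produces a 3-channel (RGB) tensor, but my model expects 1 channel. How can I create a single-channel tensor from a bitmap? The forward call fails with: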

java.lang.RuntimeException: Given groups=1, weight of size 8 1 3 3, expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead
    The above operation failed in interpreter.
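One alternative to changing the model is to build the single-channel input yourself instead of calling bitmapToFloat32Tensor. In Kotlin that would roughly mean reading the pixels with Bitmap.getPixels, converting each packed ARGB int to one gray value, and wrapping the resulting FloatArray with Tensor.fromBlob(data, longArrayOf(1, 1, 28, 28)). The per-pixel arithmetic is sketched below in plain Python so it can be checked in isolation. The helper name argb_to_gray_tensor, the BT.601 luma weights, and the MNIST mean/std of 0.1307 / 0.3081 are all assumptions; use whatever normalization your model was actually trained with.

```python
from typing import List

# Assumed normalization: the commonly used MNIST statistics.
# Replace these with the values your model was trained with.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def argb_to_gray_tensor(pixels: List[int], width: int, height: int,
                        mean: float = MNIST_MEAN,
                        std: float = MNIST_STD) -> List[float]:
    """Hypothetical helper: convert row-major packed ARGB ints (the layout
    Bitmap.getPixels fills in) into a flat float list that can be read as
    a [1, 1, height, width] tensor."""
    if len(pixels) != width * height:
        raise ValueError("expected width * height pixels")
    out = []
    for p in pixels:
        # Masking with 0xFF also works for negative (signed 32-bit) ints,
        # which is how Android returns pixels with alpha = 0xFF.
        r = (p >> 16) & 0xFF
        g = (p >> 8) & 0xFF
        b = p & 0xFF
        # BT.601 luma weights (an assumption); (r + g + b) / 3 would also work.
        gray = (0.299 * r + 0.587 * g + 0.114 * b) / 255.0
        out.append((gray - mean) / std)
    return out
```

Because the output has exactly width * height values in row-major order, it already matches the NCHW layout of a [1, 1, 28, 28] tensor with a single channel.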


Hi, I am also facing a similar issue using 1-channel images with PyTorch Android. Did you find a way around it?

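Hi, you can try this way.

Make sure that x = torch.mean(x, dim=1, keepdim=True) is the first statement in the forward method of your PyTorch model. It averages the three RGB channels into one, so the rest of the network still receives a single-channel input.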

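Then trace the model with a 3-channel example input, since that is the shape bitmapToFloat32Tensor produces on Android:

```python
example = torch.rand(1, 3, 28, 28)
trace = torch.jit.trace(net, example)
trace.save("mymodel.pt")
```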
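To make the shape change concrete, here is what torch.mean(x, dim=1, keepdim=True) computes, written out in plain Python over nested lists indexed [n][c][h][w]. The helper name mean_over_channels is hypothetical and only for illustration. A [1, 3, 28, 28] input becomes [1, 1, 28, 28], which matches the first conv layer's weight of size 8 1 3 3. One caveat: each output value is the average of the three channels after TensorImageUtils has normalized them with the ImageNet mean/std, so unless the Android side passes normalization that matches training, the model may see inputs scaled differently from what it was trained on.

```python
from typing import List

# Nested lists indexed [n][c][h][w], i.e. an NCHW tensor.
Tensor4D = List[List[List[List[float]]]]


def mean_over_channels(x: Tensor4D) -> Tensor4D:
    """Plain-Python equivalent of torch.mean(x, dim=1, keepdim=True):
    average the C channels of an [N, C, H, W] input and keep a size-1
    channel axis, giving [N, 1, H, W]."""
    result = []
    for sample in x:
        channels = len(sample)
        height = len(sample[0])
        width = len(sample[0][0])
        averaged = [[sum(sample[c][h][w] for c in range(channels)) / channels
                     for w in range(width)]
                    for h in range(height)]
        # keepdim=True: the channel axis is kept, now with size 1.
        result.append([averaged])
    return result
```

For example, a [1, 3, 1, 2] input whose channels are [3, 0], [0, 3] and [3, 3] becomes a [1, 1, 1, 2] output holding [2, 2].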