Hi, I'm trying to build an MNIST classifier with PyTorch Mobile.
I've exported my model with TorchScript (`torch.jit.trace`) in Python:
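```python
import torch

# net: the trained MNIST model; eval() makes dropout/batchnorm behave as at inference time
net.eval()
example = torch.rand(1, 1, 28, 28)  # one grayscale 28x28 image: [N, C, H, W]
trace = torch.jit.trace(net, example)
trace.save("mymodel.pt")
```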
And this is how I use it in my Kotlin project:
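```kotlin
import android.content.Context
import java.io.File
import java.io.FileOutputStream

// Copies an asset into internal storage (only once) and returns its absolute path,
// because Module.load expects a file path rather than an asset stream.
fun assetFilePath(context: Context, assetName: String): String {
    val file = File(context.filesDir, assetName)
    if (file.exists() && file.length() > 0) {
        return file.absolutePath
    }
    context.assets.open(assetName).use { inputStream ->
        FileOutputStream(file).use { outputStream ->
            val buffer = ByteArray(4 * 1024)
            var read: Int
            while (inputStream.read(buffer).also { read = it } != -1) {
                outputStream.write(buffer, 0, read)
            }
            outputStream.flush()
        }
        return file.absolutePath
    }
}
```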
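The manual buffer loop in `assetFilePath` can also be written with Kotlin's `InputStream.copyTo`. Below is a minimal sketch of that variant; `copyToFileIfMissing` is a hypothetical helper that receives the stream as a lambda (so the stream is only opened when the file still has to be copied) and has no Android dependency, which keeps it testable off-device. On Android it would be called as `copyToFileIfMissing(File(context.filesDir, assetName)) { context.assets.open(assetName) }`.

```kotlin
import java.io.File
import java.io.InputStream

// Hypothetical helper: copies the stream produced by `open` into `file`,
// unless `file` already exists and is non-empty, then returns its absolute path.
fun copyToFileIfMissing(file: File, open: () -> InputStream): String {
    if (file.exists() && file.length() > 0) return file.absolutePath
    open().use { input ->
        file.outputStream().use { output ->
            input.copyTo(output) // kotlin.io buffered copy (8 KiB buffer by default)
        }
    }
    return file.absolutePath
}
```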
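```kotlin
import android.graphics.Bitmap
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils

// Inside an Activity, so `this` is the Context
val module: Module = Module.load(assetFilePath(this, "mymodel.pt"))
val resizedBitmap: Bitmap = Bitmap.createScaledBitmap(
    MyCanvasView.extraBitmap,
    28,
    28,
    true
)
// bitmapToFloat32Tensor returns an RGB tensor of shape [1, 3, H, W],
// normalized with the ImageNet (torchvision) mean/std
val inputTensor: Tensor = TensorImageUtils.bitmapToFloat32Tensor(
    resizedBitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
    TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
```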
The problem is that `inputTensor`'s shape is [1, 3, 28, 28], but the model was traced with a single-channel input. How can I create a single-channel tensor from a bitmap? This is the error I get:

```
java.lang.RuntimeException: Given groups=1, weight of size 8 1 3 3, expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead
The above operation failed in interpreter.
```
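One way to get a single-channel input is to skip `TensorImageUtils.bitmapToFloat32Tensor` (which, per the shape above, yields three channels) and build the float buffer yourself, then wrap it with `Tensor.fromBlob(data, shape)` using shape `[1, 1, 28, 28]`. A minimal sketch, assuming: the bitmap is opaque ARGB_8888, the digit is drawn dark on a light background (MNIST digits are light on dark, hence the optional inversion), and the model was trained with `Normalize((0.1307,), (0.3081,))`. If training used only `ToTensor()`, set the mean to 0 and the std to 1. `argbToGrayFloats` and the constants are hypothetical names; the conversion is pure Kotlin so it can be checked off-device, and the Android/PyTorch calls are shown in comments.

```kotlin
// Assumed training normalization (the commonly used MNIST mean/std).
val MNIST_MEAN = 0.1307f
val MNIST_STD = 0.3081f

// Converts ARGB pixels (as filled by Bitmap.getPixels) into normalized grayscale
// floats in row-major (H, W) order, i.e. the memory layout of a [1, 1, H, W] tensor.
fun argbToGrayFloats(pixels: IntArray, invert: Boolean = true): FloatArray =
    FloatArray(pixels.size) { i ->
        val p = pixels[i]
        val r = (p shr 16) and 0xFF
        val g = (p shr 8) and 0xFF
        val b = p and 0xFF
        // ITU-R BT.601 luma weights, scaled to [0, 1]
        val luma = (0.299f * r + 0.587f * g + 0.114f * b) / 255f
        // Inversion assumes dark strokes on a light canvas
        val gray = if (invert) 1f - luma else luma
        (gray - MNIST_MEAN) / MNIST_STD
    }

// Usage inside the Activity (requires the pytorch_android dependency):
// val pixels = IntArray(28 * 28)
// resizedBitmap.getPixels(pixels, 0, 28, 0, 0, 28, 28)
// val inputTensor = Tensor.fromBlob(argbToGrayFloats(pixels), longArrayOf(1, 1, 28, 28))
// val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
```

The predicted digit would then be the index of the largest value in `outputTensor.dataAsFloatArray`.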