Rookie problem: spiralling memory usage on inference

Hi! I’m transitioning from TensorFlow to PyTorch in my Android application, and I’m running into a problem when I do inference: memory usage increases quickly until the app crashes. Android Studio reports that all of the memory usage is Native. I’m clearly abusing the library somewhere, or not releasing something that needs releasing. I’ve looked through most of the sample apps and can’t find any clues.

I’m processing 15 frames per second from a screen recording.

My code looks like this. A “slice” is just a data object holding the bitmaps I’m using (cropped, resized, original capture). I know those are being freed because when I remove the inference step but leave the rest of the image pipeline in place, memory usage is fine.

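// Imports assumed for this snippet (Filter, Slice, assetFilePath and TAG are app-specific).
import android.util.Log
import org.pytorch.IValue
import org.pytorch.LiteModuleLoader
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils

class EvaluationFilter : Filter {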
    // Lazy initialization so we can ensure the context is set up in the service.
    private val module: Module? by lazy {
        val path = assetFilePath("image_qual3.ptl")
        LiteModuleLoader.load(path)
        //Module.load(assetFilePath("image_qual1.pt"))
    }

    // Perform inference on a screen capture. Called about 15 times per second
    // while the user is recording.
    override fun processSlice(slice: Slice): Slice {
        slice.croppedScaledBitmap?.let { bitmap ->
            // bitmapToFloat32Tensor expects the mean array first, then the std array.
            val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
                TensorImageUtils.TORCHVISION_NORM_STD_RGB)
            val output = module?.forward(IValue.from(inputTensor))?.toTensor()
            val scores = output?.dataAsFloatArray
            scores?.let { slice.quality = it[1] }
            Log.d(TAG, "Scores are ${scores?.contentToString()}")
        }
        return slice
    }
}

Thanks for any help, I’m at a bit of a loss. I’m sure it will be a forehead-slapper when it’s revealed.

I am not very familiar with Android app development. Can you elaborate on what you mean here?

If the issue is a memory leak in the PyTorch runtime, then we may have to figure out where it is coming from. Is this a quantized model or a floating-point model? Do you know?

Hi, I’m having the same problem. I have been using PyTorch for around 2 years, and for all that time I have been looking for an answer or some way to release the used memory, which increases constantly. Now I need to process more frames per second, and the app is crashing faster.

My code is similar to the code in the original question from @rfrey.

I figured out that the native memory always increases on the module.forward() call.
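For anyone who wants to confirm the same observation in their own app, here is a minimal sketch of how the native heap could be sampled around a single call. It assumes android.os.Debug.getNativeHeapAllocatedSize() is an acceptable proxy for the "Native" figure in the Android Studio profiler; logNativeHeapDelta is a hypothetical helper name, not part of PyTorch.

```kotlin
import android.os.Debug
import android.util.Log

// Hypothetical helper: runs `block` and logs the native heap size before and after it.
// The numbers are only indicative, since other threads may allocate at the same time.
inline fun <T> logNativeHeapDelta(tag: String, label: String, block: () -> T): T {
    val before = Debug.getNativeHeapAllocatedSize()
    val result = block()
    val after = Debug.getNativeHeapAllocatedSize()
    Log.d(tag, "$label: native heap ${before / 1024} KiB -> ${after / 1024} KiB " +
        "(delta ${(after - before) / 1024} KiB)")
    return result
}
```

It could be wrapped around the inference call, for example logNativeHeapDelta(TAG, "forward") { module.forward(IValue.from(inputTensor)) }, and around the tensor creation separately to see which step grows.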

I have already tried reusing the FloatBuffer instance like this, freeing/rewinding it, calling tensor.mHybridData.resetNative(), and many other approaches. I couldn’t use the imageYUV420CenterCropToFloatBuffer method because my image source is a Bitmap. I tried converting the Bitmap to YUV420, but I only get NV21 data in a byte[] (PyTorch requires the three planes of a YUV image), which is not compatible with PyTorch.
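For reference, this is roughly what buffer reuse looks like when the source is a Bitmap. It is a sketch, not a confirmed fix for the leak: it allocates one direct buffer and one input Tensor up front and refills the buffer with TensorImageUtils.bitmapToFloatBuffer, the Bitmap counterpart of the YUV method. ReusableInput is a hypothetical class name, and the sketch assumes the model takes a 1x3xHxW float input and that the bitmap is already cropped and scaled to that size.

```kotlin
import android.graphics.Bitmap
import java.nio.FloatBuffer
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import org.pytorch.torchvision.TensorImageUtils

// Hypothetical wrapper that keeps one input buffer and one input tensor alive
// for the whole recording instead of creating new ones for every frame.
class ReusableInput(private val width: Int, private val height: Int) {
    // Direct buffer allocated once; the tensor below is a view over it.
    private val buffer: FloatBuffer = Tensor.allocateFloatBuffer(3 * width * height)
    private val tensor: Tensor =
        Tensor.fromBlob(buffer, longArrayOf(1, 3, height.toLong(), width.toLong()))

    // Writes the normalized pixels of `bitmap` into the shared buffer and runs the model.
    fun run(module: Module, bitmap: Bitmap): FloatArray {
        require(bitmap.width == width && bitmap.height == height) {
            "Bitmap must be ${width}x${height}"
        }
        TensorImageUtils.bitmapToFloatBuffer(
            bitmap, 0, 0, width, height,
            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
            TensorImageUtils.TORCHVISION_NORM_STD_RGB,
            buffer, 0
        )
        return module.forward(IValue.from(tensor)).toTensor().dataAsFloatArray
    }
}
```

This only removes the per-frame input allocation. The output tensor returned by forward() is still created on every call, so if the growth comes from there, this sketch will not change it.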

Has anyone found a solution for the memory issue that they could share with us? A way to convert the Bitmap to YUV420 and get its image planes would also be useful.
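On the conversion question, here is a minimal sketch of turning ARGB pixels into three separate Y, U and V planes in plain Kotlin. It assumes full-range BT.601 (JFIF) coefficients, even image dimensions, and simple 2x2 subsampling that takes the top-left pixel of each block. Yuv420Planes and argbToYuv420Planes are hypothetical names. Note that imageYUV420CenterCropToFloatBuffer takes an android.media.Image, so these byte arrays cannot be passed to it directly.

```kotlin
import kotlin.math.roundToInt

// Hypothetical holder for the three planes; U and V are a quarter of the Y size.
class Yuv420Planes(val y: ByteArray, val u: ByteArray, val v: ByteArray)

// Rounds to the nearest integer, clamps to 0..255 and stores it as an unsigned byte value.
private fun toUnsignedByte(value: Double): Byte = value.roundToInt().coerceIn(0, 255).toByte()

// Converts packed ARGB pixels (row-major, as filled by Bitmap.getPixels) to planar YUV 4:2:0.
fun argbToYuv420Planes(argb: IntArray, width: Int, height: Int): Yuv420Planes {
    require(width % 2 == 0 && height % 2 == 0) { "Width and height must be even" }
    require(argb.size == width * height) { "Pixel array does not match dimensions" }
    val yPlane = ByteArray(width * height)
    val uPlane = ByteArray(width * height / 4)
    val vPlane = ByteArray(width * height / 4)
    for (row in 0 until height) {
        for (col in 0 until width) {
            val pixel = argb[row * width + col]
            val r = (pixel shr 16) and 0xFF
            val g = (pixel shr 8) and 0xFF
            val b = pixel and 0xFF
            yPlane[row * width + col] = toUnsignedByte(0.299 * r + 0.587 * g + 0.114 * b)
            // One chroma sample per 2x2 block, taken from the block's top-left pixel.
            if (row % 2 == 0 && col % 2 == 0) {
                val i = (row / 2) * (width / 2) + col / 2
                uPlane[i] = toUnsignedByte(-0.168736 * r - 0.331264 * g + 0.5 * b + 128)
                vPlane[i] = toUnsignedByte(0.5 * r - 0.418688 * g - 0.081312 * b + 128)
            }
        }
    }
    return Yuv420Planes(yPlane, uPlane, vPlane)
}
```

On Android the input array could be filled with bitmap.getPixels(pixels, 0, width, 0, 0, width, height). Averaging the four pixels of each block instead of taking the top-left one would give smoother chroma at a small extra cost.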