[Android] Process gets killed after calling forward() on the model

I’m trying to use the SuperPoint model in my Android application. It comes with pretrained weights, and I export it for mobile as follows:

import torch

module = SuperPointNet()  # SuperPoint model class (definition not shown)
module.load_state_dict(torch.load("..."))
module.eval()

torch.quantization.fuse_modules(
    module,
    modules_to_fuse=[["conv1a", "relu1a"], ["conv1b", "relu1b"],
                     ["conv2a", "relu2a"], ["conv2b", "relu2b"],
                     ["conv3a", "relu3a"], ["conv3b", "relu3b"],
                     ["conv4a", "relu4a"], ["conv4b", "relu4b"],
                     ["convPa", "reluPa"], ["convDa", "reluDa"]],
    inplace=True
)

scripted = torch.jit.script(module)
scripted._save_for_lite_interpreter("...")

The saved model works fine with images up to roughly 720p. But when I forward larger images, it runs for a while, then the UI starts lagging, and eventually the system kills the process (I assume because the app stops responding). Before that, the logs show messages like these (taken from an emulator, but real devices behave the same way):

2022-08-15 08:57:54.339 7249-7249/com.github.kpdandroid I/Choreographer: Skipped 30 frames!  The application may be doing too much work on its main thread.
2022-08-15 08:57:54.965 7249-7280/com.github.kpdandroid D/EGL_emulation: app_time_stats: avg=1255.40ms min=1255.40ms max=1255.40ms count=1
2022-08-15 08:57:55.080 7249-7552/com.github.kpdandroid I/OpenGLRenderer: Davey! duration=2458ms; Flags=0, FrameTimelineVsyncId=24646, IntendedVsync=439844396372, Vsync=440361063018, InputEventId=0, HandleInputStart=440834660800, AnimationStart=440834725000, PerformTraversalsStart=441327235600, DrawStart=441465262000, FrameDeadline=439877729704, FrameInterval=440362785000, FrameStartTime=16666666, SyncQueued=441576657100, SyncStart=441607291100, IssueDrawCommandsStart=441639554200, SwapBuffers=442253047100, FrameCompleted=442333966600, DequeueBufferDuration=98000, QueueBufferDuration=37406400, GpuCompleted=442298839300, SwapBuffersCompleted=442333966600, DisplayPresentTime=35188667079900, 
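Because the failure only starts above a certain resolution, one thing worth ruling out is memory pressure: the first two convolutions produce 64-channel float32 tensors at full input resolution. The sketch below gives a rough per-layer estimate. It assumes the layer widths of the MagicLeap reference SuperPointNet (64, 64, 128, 128 encoder channels, 256-wide heads, 65 detector and 256 descriptor channels) and ignores weights, allocator overhead, and any extra copies the mobile interpreter may make, so real usage is likely higher.

```python
# Rough float32 activation sizes for a SuperPoint-style network.
# ASSUMPTION: channel widths follow the MagicLeap reference SuperPointNet;
# adjust CHANNELS if the model in use differs. Weights and runtime overhead
# are not counted.

BYTES_PER_FLOAT32 = 4

# (layer name, output channels, downsampling factor relative to the input)
CHANNELS = [
    ("conv1a", 64, 1), ("conv1b", 64, 1),
    ("conv2a", 64, 2), ("conv2b", 64, 2),
    ("conv3a", 128, 4), ("conv3b", 128, 4),
    ("conv4a", 128, 8), ("conv4b", 128, 8),
    ("convPa", 256, 8), ("convPb", 65, 8),
    ("convDa", 256, 8), ("convDb", 256, 8),
]


def activation_bytes(height: int, width: int) -> dict:
    """Size in bytes of each layer's output for one grayscale image (batch 1).

    Each 2x2 max-pool floors the spatial size, which integer division by
    the downsampling factor reproduces.
    """
    return {
        name: channels * (height // factor) * (width // factor) * BYTES_PER_FLOAT32
        for name, channels, factor in CHANNELS
    }


if __name__ == "__main__":
    for width, height in [(1280, 720), (1920, 1080), (3840, 2160)]:
        sizes = activation_bytes(height, width)
        # conv1a's output must stay alive while conv1b writes its own output,
        # so at least these two full-resolution tensors coexist.
        peak_pair = sizes["conv1a"] + sizes["conv1b"]
        print(f"{width}x{height}: conv1a + conv1b ~ {peak_pair / 2**20:.0f} MiB")
```

Under these assumptions the conv1a/conv1b pair alone is about 450 MiB at 1280x720, and it grows linearly with pixel count, so 1920x1080 needs 2.25 times as much.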

Of course, I don’t call forward() on the main thread. This happens both when I use CameraX with ImageAnalysis running on a dedicated Executors.newSingleThreadExecutor() and when I run the model on images from the file system in a coroutine launched in the Dispatchers.Default context (I also tried launching it on a dedicated thread, but that doesn’t help).

I also tried limiting PyTorch to a single thread with PyTorchAndroid.setNumThreads(1), but to no avail.
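If working at a lower resolution is acceptable, one workaround worth trying, given that inputs up to about 720p work, is to downscale larger frames before calling forward(). Below is a minimal sketch of the size computation only. The name capped_size and its defaults are hypothetical, and rounding both sides down to a multiple of 8 assumes the downstream decoding expects the three 2x2 poolings to divide evenly. On Android the resize itself could then be done with Bitmap.createScaledBitmap.

```python
# Hypothetical helper: choose a downscaled size for frames larger than the
# resolution that is known to work (about 1280x720 in either orientation).
# Keeps the aspect ratio and rounds both sides down to a multiple of 8 so the
# three 2x2 max-pools divide evenly (ASSUMPTION: downstream decoding expects it).
from fractions import Fraction
from typing import Tuple


def capped_size(width: int, height: int,
                max_long: int = 1280, max_short: int = 720,
                multiple: int = 8) -> Tuple[int, int]:
    """Return (new_width, new_height) that fits a max_long x max_short box
    in either orientation. Frames already inside the box are not upscaled,
    only rounded down to `multiple`."""
    long_side, short_side = max(width, height), min(width, height)
    # Exact rational arithmetic avoids off-by-one errors from float rounding.
    scale = min(Fraction(1),
                Fraction(max_long, long_side),
                Fraction(max_short, short_side))
    new_width = int(width * scale) // multiple * multiple
    new_height = int(height * scale) // multiple * multiple
    return max(new_width, multiple), max(new_height, multiple)


if __name__ == "__main__":
    for size in [(1920, 1080), (1080, 1920), (4000, 3000), (640, 480)]:
        print(size, "->", capped_size(*size))
```

Keypoint coordinates found on the smaller image would then need to be scaled back by width / new_width and height / new_height.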

What can I do about this?