Torch.jit.trace() memory leak workarounds

It has been reported elsewhere that torch.jit.trace() has a memory leak. This is a problem for me because my application calls torch.jit.trace().save() repeatedly in a loop (the saved models are later loaded in C++ with LibTorch).

One workaround suggested here is to run the torch.jit.trace() call in a separate process using concurrent.futures.ProcessPoolExecutor, so that any leaked memory is released when the worker process exits.
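A minimal sketch of that workaround, assuming the model and example inputs are picklable. `trace_and_save` and `run_in_worker_process` are hypothetical helper names. A fresh single-worker pool is created per call so that the worker process, and whatever it leaked, goes away afterwards. The `fork` start method is requested explicitly to mirror what `ProcessPoolExecutor` has historically used by default on Linux; that is the configuration the deadlock described in the next paragraph applies to, so this sketch is only safe in a single-threaded parent.

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def trace_and_save(model, example_inputs, out_path):
    """Hypothetical worker: trace `model` and write the TorchScript file."""
    # Imported here only so this module loads without PyTorch installed;
    # the trace itself (and anything it leaks) happens in the worker process.
    import torch

    torch.jit.trace(model, example_inputs).save(out_path)
    return out_path


def run_in_worker_process(fn, *args):
    """Run fn(*args) in a fresh single-worker process pool and return the result.

    The pool is shut down when the `with` block exits, so the worker process
    terminates and the OS reclaims any memory the call leaked. Exceptions
    raised by fn are re-raised here by .result().
    """
    # "fork" mirrors the historical Linux default. Forking a parent that has
    # other threads running is exactly what can deadlock.
    ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        return pool.submit(fn, *args).result()
```

For example, `run_in_worker_process(trace_and_save, model, example, "model.pt")` would trace in the child process and write `model.pt` there.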

This appears to work well in single-threaded settings. In multi-threaded settings, however, it occasionally deadlocks, due to a variant of this issue. In general, forking a multi-threaded process in Python is known to be fragile: the child contains only a copy of the forking thread, along with any locks that other threads happened to hold at that moment.
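To illustrate why forking a multi-threaded process is fragile, here is a small POSIX-only demonstration; it is not PyTorch-specific, and the function name is hypothetical. A second thread holds a lock while the main thread forks. In the child that thread does not exist, so nothing can ever release the lock, and a plain `acquire()` would hang forever; a timeout is used so the demo terminates. Recent CPython versions (3.12+) also emit a `DeprecationWarning` when `os.fork()` is called while other threads are running, for this reason.

```python
import os
import threading


def child_can_acquire_lock_after_fork(timeout=1.0):
    """Fork while another thread holds a lock; report whether the child can take it."""
    lock = threading.Lock()
    holding = threading.Event()
    release = threading.Event()

    def holder():
        with lock:
            holding.set()   # tell the main thread the lock is now held
            release.wait()  # keep holding it until the fork is finished

    t = threading.Thread(target=holder)
    t.start()
    holding.wait()

    pid = os.fork()
    if pid == 0:
        # Child: only this thread was copied, so the lock's owner is gone.
        # A blocking acquire() would hang; use a timeout instead.
        acquired = lock.acquire(timeout=timeout)
        os._exit(0 if acquired else 1)

    # Parent: collect the child's verdict, then let the holder thread finish.
    _, status = os.waitpid(pid, 0)
    release.set()
    t.join()
    return os.WEXITSTATUS(status) == 0
```

On CPython this returns `False`: the child times out waiting for a lock whose owner was never copied into it.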

I am looking for a workaround. Is there a good way to call torch.jit.trace().save() in a loop in a multi-threaded application without deadlocking and without leaking memory?

If the answer is “upgrade to PyTorch 2.0 and use torch.compile()”, I’m OK with that. Anything that works.

That might indeed be the recommendation here, since TorchScript is in maintenance mode and will not receive any major updates or bug fixes anymore. I’m also hesitant to look for other workarounds, since they still carry the risk of running into a new issue that won’t be fixed either.


Thanks for the reply. Guess it’s time to upgrade!


I started looking into the PyTorch 2.0 documentation and am a bit confused. The latest page on exporting models to C++ appears to be this one, which still talks about using TorchScript. But I understood your answer to mean that, after upgrading to PyTorch 2.0, TorchScript will no longer be needed to export a model to C++. What gives?

Is the recommended way to export to C++ now to use the ONNX format? And if so, should the C++ side use a non-torch library such as ONNX Runtime?

There was in fact a simple fix for my original problem. Instead of running the torch.jit.trace() call in a separate process via concurrent.futures.ProcessPoolExecutor, launch it with subprocess: pickle the arguments to disk, then reload them in the subprocess. Because subprocess starts a fresh interpreter instead of forking a copy of the multi-threaded parent, this sidesteps the multi-threading + fork sensitivities.
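A minimal sketch of the subprocess approach, under a few assumptions: the payload is picklable, and the model's class is importable from a fresh interpreter (for example, defined in a module on the path rather than in the parent script's `__main__`). `TRACE_SCRIPT`, `run_in_subprocess`, and `trace_and_save` are hypothetical names. With `python -c`, any arguments after the code string land in `sys.argv[1:]` in order, which is how the child finds the pickle file and the output path.

```python
import os
import pickle
import subprocess
import sys
import tempfile

# Code run by the child interpreter for the tracing case (hypothetical layout):
# argv[1] is the pickled (model, example_inputs), argv[2] the output path.
TRACE_SCRIPT = """
import pickle, sys
import torch
with open(sys.argv[1], "rb") as f:
    model, example_inputs = pickle.load(f)
torch.jit.trace(model, example_inputs).save(sys.argv[2])
"""


def run_in_subprocess(script, payload, *extra_args):
    """Pickle `payload` to a temp file and run `script` in a fresh interpreter.

    The child receives the pickle path as sys.argv[1], followed by extra_args.
    subprocess execs a new program in the child right after creating it, so
    the child does not inherit the parent's threads or lock state.
    Raises subprocess.CalledProcessError if the child exits non-zero.
    """
    with tempfile.TemporaryDirectory() as tmp:
        args_path = os.path.join(tmp, "args.pkl")
        with open(args_path, "wb") as f:
            pickle.dump(payload, f)
        subprocess.run(
            [sys.executable, "-c", script, args_path, *extra_args],
            check=True,
        )


def trace_and_save(model, example_inputs, out_path):
    """Hypothetical helper: trace in a child process and save TorchScript."""
    run_in_subprocess(TRACE_SCRIPT, (model, example_inputs), out_path)
```

Each call blocks only the calling thread while the child runs, and each call gets its own temporary directory, so `trace_and_save` can be called from several threads at once without the argument files clobbering each other.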