C++ Frontend: avoid spawning extra threads in already multithreaded code

I am writing a multithreaded C++ application, in which each thread needs to perform some PyTorch computations.

Unfortunately, controlling where these computations run turns out to be quite hard. Here is what I found:

  • if I do not do anything special, PyTorch (via OpenMP) creates a thread pool spanning all available cores, and all my threads dispatch their operations to this pool
  • if I call at::set_num_threads(1) (either once globally at startup, or inside each thread when it is spawned), PyTorch spawns a single OpenMP thread, which then receives the operations dispatched from all my threads.

Neither case gives me the control I need in my scenario. I need the PyTorch operations to be executed by the thread that dispatched them: if one of my threads adds two tensors together, the addition should be performed by that thread, without being dispatched via OpenMP.

What options are available to achieve this? So far, the only idea I have is to build libtorch manually from source with the environment variable MKL_THREADING=SEQ. However, this might take quite a lot of time and effort to set up correctly, and I'm not sure it will work.

So, has anyone found themselves in this situation? Any other ideas or workarounds?