OpenMP parallelism over a pool of modules


I am trying to train a bunch of small networks concurrently. Each training step involves a call to torch::jit::Module::forward and a subsequent call to torch::autograd::grad on the result. To be more precise, here's what I am doing:

  1. For each worker thread I create a copy of the original module by calling Module::clone.
  2. Within each worker thread I use only the corresponding copy of the module to calculate the value and gradient from thread-specific inputs.
  3. After the value and gradient are calculated, they are used to update the state of the model, which is stored separately.

So each OMP worker has its own copy of the module, which it uses to process its payload.
What I noticed is that as the number of workers increases (the amount of payload per worker stays the same), performance degrades roughly like this:

1 worker(s) - 11s
2 worker(s) - 16s
3 worker(s) - 20.5s
4 worker(s) - 25s

I removed the autograd part from the loop body, and the performance issue persists with approximately the same scaling, so it is not related to synchronization within gradient computation, although it might be related to graph construction:

1 worker(s) - 3.56s
2 worker(s) - 4.61s
3 worker(s) - 5.91s
4 worker(s) - 7.58s

To avoid potential conflicts between parallelization libraries, I called torch.set_num_threads(1) before starting this test. The whole test is driven from Python; the function with the OMP loop is loaded from a pybind extension module.
What bothers me most is that launching the same tasks in separate processes does not show such a dramatic performance drop. However, it is much easier to work with threads.

Why is this happening and what are the workarounds?