Using shared memory to share a model across multiple processes makes memory usage explode

Hello, I am new to PyTorch. I am currently working on a reinforcement learning task, and I want to share a model among N processes on a single machine.

One of the processes is responsible for updating the weights of the model, and the other N-1 processes use the model for inference (i.e. the actor in actor-critic).
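For reference, here is a minimal sketch of the setup described above. It follows the same pattern as PyTorch's mnist_hogwild example: call `model.share_memory()` once in the parent, then pass the model to each `Process`. The network shape, the placeholder loss, the random observations, and the helper names `build_model`, `learner`, `actor` and `main` are hypothetical. `env.step()` is only indicated by a comment.

```python
import torch
import torch.multiprocessing as mp
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the actor-critic network (4-dim observation, 2 actions).
    return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))


def learner(model: nn.Module, steps: int) -> None:
    # The single updating process: optimizer steps write in place into the shared storages.
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(steps):
        loss = model(torch.randn(8, 4)).pow(2).mean()  # placeholder loss
        opt.zero_grad()
        loss.backward()
        opt.step()


def actor(rank: int, model: nn.Module, steps: int) -> None:
    # One of the N-1 inference processes; it reads the storages the learner writes.
    torch.manual_seed(rank)
    torch.set_num_threads(1)  # one intra-op thread per process to avoid CPU oversubscription
    with torch.no_grad():
        for _ in range(steps):
            obs = torch.randn(1, 4)  # placeholder for the current env observation
            action = model(obs).argmax(dim=1).item()
            # env.step(action) would run here, on the CPU


def main(num_actors: int = 3, steps: int = 100) -> None:
    model = build_model()
    model.share_memory()  # move parameter storages to shared memory before starting processes
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=actor, args=(rank, model, steps)) for rank in range(num_actors)]
    procs.append(ctx.Process(target=learner, args=(model, steps)))
    for p in procs:
        p.start()
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()
```

Setting `torch.set_num_threads(1)` in the actors is a common choice when many inference processes run on one machine, because otherwise each process sizes its intra-op thread pool to all cores. It affects CPU contention, not memory.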

I chose this design because the inference processes also need to do some work on the CPU (env.step()) at the same time, so I hope to improve efficiency by using multiple processes to get around Python's GIL.

However, I found that whether I put the model on the CPU or the GPU, as long as I use the shared memory mechanism, RAM and GPU memory consumption increases linearly with the number of processes. It looks as if each child process makes its own copy of the model.
Moreover, with a shared model, each child process uses more memory when the model is deployed on CUDA than when it is on the CPU.
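One thing worth ruling out for the linear growth is how it is measured. RSS counts every resident page a process has mapped, including shared ones. Summing RSS over N processes therefore counts a shared-memory model (and the shared library pages of PyTorch itself) N times. On Linux, PSS (proportional set size) divides each shared page among the processes that map it, so summed PSS is closer to the real total. The following is a minimal sketch that reads both fields from `/proc/<pid>/smaps_rollup` (available since Linux 4.14; this assumes a Linux machine). `parse_smaps_rollup`, `memory_report` and `summarize` are hypothetical helper names.

```python
import os
from typing import Dict, List


def parse_smaps_rollup(text: str) -> Dict[str, int]:
    """Turn smaps_rollup content into {field name: size in kB}."""
    fields: Dict[str, int] = {}
    for line in text.splitlines():
        parts = line.split()
        # Data lines look like "Pss:   123456 kB"; the first (address-range) line is skipped.
        if len(parts) == 3 and parts[0].endswith(":") and parts[2] == "kB":
            fields[parts[0][:-1]] = int(parts[1])
    return fields


def memory_report(pid: int) -> Dict[str, int]:
    with open(f"/proc/{pid}/smaps_rollup") as f:
        return parse_smaps_rollup(f.read())


def summarize(pids: List[int]) -> Dict[str, int]:
    """Sum RSS and PSS (kB) over a set of processes, e.g. the parent plus its workers."""
    totals = {"Rss": 0, "Pss": 0}
    for pid in pids:
        report = memory_report(pid)
        for key in totals:
            totals[key] += report.get(key, 0)
    return totals


if __name__ == "__main__" and os.path.exists("/proc/self/smaps_rollup"):
    print(summarize([os.getpid()]))
```

In the training script, the PIDs would be `os.getpid()` plus `p.pid` for each started `Process`. If summed PSS grows much more slowly than summed RSS, the "copy per process" is largely an accounting effect.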

In my example, using spawn to create the child processes, the RSS of each process is about 1.24 GB when the model is on the CPU and about 3.40 GB when it is on the GPU.
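For the GPU case, PyTorch's own memory counters cover only memory obtained through its caching allocator. The per-process CUDA context is not included, and the CUDA libraries that each process loads into host memory can show up in its RSS. The diagnostic below is a small sketch that assumes you call it inside each worker. `cuda_allocator_stats` is a hypothetical helper built on `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved`. Comparing its output with the per-process usage listed by `nvidia-smi` gives a rough idea of how much of each process's GPU footprint is fixed overhead rather than tensor storage.

```python
import os
from typing import Dict, Optional

import torch


def cuda_allocator_stats(device: int = 0) -> Optional[Dict[str, int]]:
    """Bytes held by PyTorch's caching allocator in this process on `device`.

    Only allocator-managed tensor memory is counted; the CUDA context and loaded
    kernels are not, so nvidia-smi's per-process figure is expected to be larger.
    Returns None when no GPU is available.
    """
    if not torch.cuda.is_available():
        return None
    return {
        "pid": os.getpid(),
        "allocated_bytes": torch.cuda.memory_allocated(device),  # live tensors
        "reserved_bytes": torch.cuda.memory_reserved(device),  # allocated plus cached blocks
    }


if __name__ == "__main__":
    print(cuda_allocator_stats())
```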

In addition, when the model is on the GPU, each child process spawns an additional child process that occupies 0.94 GB of RSS.

All of this leads to poor scalability for our project.

If I use message passing instead of shared memory (i.e., the main process performs not only the updates but also inference on behalf of the other sub-processes, and synchronizes with them through message passing), this brings a lot of I/O overhead, and the main process's CPU becomes the bottleneck of the whole system.
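If you revisit the message-passing design, centralized batched inference is one way to reduce the per-request cost on the main process. Instead of running one forward pass per request, the inference process drains whatever requests are waiting and runs them as a single batch. The following is a minimal sketch that assumes each actor puts `(actor_id, observation)` on a shared request queue and blocks on its own reply queue. `gather_batch`, `serve_once`, `max_batch` and this queue layout are hypothetical. Whether this relieves the CPU bottleneck in your setup would need to be measured.

```python
import queue
from typing import List, Tuple

import torch
import torch.nn as nn


def gather_batch(request_q, max_batch: int, timeout: float = 0.01) -> List[Tuple[int, torch.Tensor]]:
    """Wait briefly for one request, then drain up to max_batch - 1 more without blocking."""
    batch = [request_q.get(timeout=timeout)]  # raises queue.Empty if nothing arrives
    while len(batch) < max_batch:
        try:
            batch.append(request_q.get_nowait())
        except queue.Empty:
            break
    return batch


@torch.no_grad()
def serve_once(model: nn.Module, request_q, reply_qs, max_batch: int) -> int:
    """Run one batched forward pass, send each actor its action, and return the batch size."""
    try:
        batch = gather_batch(request_q, max_batch)
    except queue.Empty:
        return 0
    actor_ids = [actor_id for actor_id, _ in batch]
    obs = torch.stack([o for _, o in batch])  # one (B, obs_dim) tensor instead of B small calls
    actions = model(obs).argmax(dim=1)
    for actor_id, action in zip(actor_ids, actions.tolist()):
        reply_qs[actor_id].put(action)
    return len(batch)


if __name__ == "__main__":
    # Single-process demo with queue.Queue; in the real setup these would be
    # torch.multiprocessing queues shared with the actor processes.
    model = nn.Linear(4, 2)
    requests = queue.Queue()
    replies = [queue.Queue() for _ in range(3)]
    for actor_id in range(3):
        requests.put((actor_id, torch.randn(4)))
    print(serve_once(model, requests, replies, max_batch=8))  # prints 3
```

The server loop would call `serve_once` repeatedly, interleaved with learner updates. A larger `max_batch` means fewer forward passes per actor step, at the cost of actors occasionally waiting a little longer for a reply.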

So here I have two main questions:

  1. Why does total memory usage increase linearly with the number of processes when using shared memory? Is my code wrong, or is there some other problem?
  2. As seen above, shared memory and message passing each have their own issues. Is there a way to work around them and achieve the result I want?

I put the code of my experiments here, and the results of the experiments are in the file.
These experiments are based on PyTorch's mnist_hogwild example.

I can share more details (e.g., screenshots) if needed.

Thanks a lot.