Multiprocessing shared memory dramatically slows NN evaluation?

I am trying to follow the PyTorch multiprocessing docs to set up a basic worker pool of processes that play games for my RL algorithm. The game playing only requires forward passes on my neural net, so I don’t need gradients (I’m using torch.no_grad()).

I’m calling model.share_memory() before passing my model to the workers using torch.multiprocessing.Pool.apply_async. What I find, though, is that although the model is passed to the other processes very quickly, the forward pass in those processes becomes much slower, to the point that the whole thing is always slower than a single-process implementation.
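For reference, here is a stripped-down sketch of roughly what I’m doing (the tiny Sequential net, the game loop, and the sizes below are placeholders for my real code):

```python
import torch
import torch.multiprocessing as mp

def play_games(model, n_games):
    # Worker: forward passes only, no gradients needed.
    with torch.no_grad():
        for _ in range(n_games):
            obs = torch.randn(1, 128)      # placeholder observation
            _policy = model(obs)           # this is what gets slow in the workers
    return n_games

if __name__ == "__main__":
    model = torch.nn.Sequential(            # placeholder for my real net
        torch.nn.Linear(128, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 64),
    )
    model.share_memory()                    # put parameters in shared memory

    with mp.Pool(processes=4) as pool:
        results = [pool.apply_async(play_games, (model, 100)) for _ in range(4)]
        print([r.get() for r in results])
```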

I tried profiling my processes to see what was causing the slowdown, but it looks like the basic PyTorch forward-pass operations are taking up the vast majority of the time:

```
 %Own   %Total  OwnTime  TotalTime  Function (filename:line)
 37.00%  37.00%   685.1s    685.1s   linear (torch/nn/functional.py:1676)
 22.00%  22.00%   464.5s    464.5s   layer_norm (torch/nn/functional.py:2048)
  4.00%   4.00%   166.2s    166.2s   softmax (torch/nn/functional.py:1498)
 13.00%  13.00%   165.3s    165.3s   multi_head_attention_forward (torch/nn/functional.py:4130)
  4.00%   4.00%   164.9s    164.9s   multi_head_attention_forward (torch/nn/functional.py:4108)
```

The slowdown gets worse the more processes I add to the pool (overall throughput keeps going down). I can also tell that the problem isn’t a one-time overhead, because doing just a single forward pass in each worker is not that much slower; it’s over many forward passes that the slowdown becomes extremely noticeable.

I have no idea what could be causing this slowdown. Is it expected that inference running from shared memory is dramatically slower? I have also tried changing the multiprocessing context to “fork” or “spawn”, to no avail.
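For completeness, this is roughly how I was switching the start method (again with a toy model standing in for mine; “fork” is only available on Unix):

```python
import torch
import torch.multiprocessing as mp

def forward_once(model):
    # A single no-grad forward pass in the worker.
    with torch.no_grad():
        return model(torch.randn(1, 128)).sum().item()

if __name__ == "__main__":
    model = torch.nn.Linear(128, 64)       # placeholder net
    model.share_memory()

    # Explicit start method; I tried both and neither changed the timings.
    ctx = mp.get_context("spawn")          # or "fork"
    with ctx.Pool(processes=4) as pool:
        print([pool.apply_async(forward_once, (model,)).get() for _ in range(4)])
```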

Any updates on this? I’m sort of having the same issues

Sadly, no. What I ultimately ended up doing was moving all the NN inference into a single process, which takes RPCs from the other processes and runs the inference as a single batch. This way the NN is not shared between processes, so I avoid the slowdown. It may not work for your use case, however.
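Very roughly, the shape of it is the sketch below, with multiprocessing queues standing in for the RPC layer; the model, sizes, and batching scheme are simplified placeholders rather than my actual code:

```python
import torch
import torch.multiprocessing as mp

def inference_server(model, request_q, result_qs, n_requests, batch_size):
    # A single process owns the model, so nothing is shared across processes.
    with torch.no_grad():
        for _ in range(n_requests // batch_size):
            batch = [request_q.get() for _ in range(batch_size)]
            ids, obs = zip(*batch)
            out = model(torch.stack(obs))          # one batched forward pass
            for wid, row in zip(ids, out):
                result_qs[wid].put(row)

def game_worker(wid, request_q, result_q, n_steps):
    for _ in range(n_steps):
        obs = torch.randn(128)                     # placeholder observation
        request_q.put((wid, obs))                  # "RPC" to the inference server
        _policy = result_q.get()                   # block until the reply arrives

if __name__ == "__main__":
    model = torch.nn.Linear(128, 64)               # placeholder net
    n_workers, n_steps = 8, 16

    request_q = mp.Queue()
    result_qs = [mp.Queue() for _ in range(n_workers)]
    workers = [mp.Process(target=game_worker,
                          args=(w, request_q, result_qs[w], n_steps))
               for w in range(n_workers)]
    for p in workers:
        p.start()

    # batch_size == n_workers, so each batch holds exactly one request per
    # worker and the fixed request count divides evenly (this toy version
    # would block otherwise).
    inference_server(model, request_q, result_qs,
                     n_requests=n_workers * n_steps, batch_size=n_workers)

    for p in workers:
        p.join()
```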