How does monobeast-style multiprocessing work exactly? How should I integrate JIT compilation?

I’m doing policy gradient with a modified monobeast. That means the actors are processes started with torch.multiprocessing (torch.mp), and the target function they run receives the actor net as an argument. Before the processes are started, actor_net.share_memory() is called.
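A minimal sketch of that setup, assuming a "fork" start context (as far as I recall, monobeast itself uses mp.get_context("fork")) and a toy nn.Linear standing in for the actor net; act and run_actors are hypothetical names. After share_memory(), the parameter storage lives in shared memory, so an in-place write the parent makes after the children have started is visible to them. With the "spawn" start method the net is pickled instead, but torch.multiprocessing passes handles to the same shared storage, so the outcome should be the same.

```python
import torch
import torch.multiprocessing as mp
import torch.nn as nn


def act(actor_id, actor_net, ready, results):
    # Hypothetical actor body: wait for the parent's write, then report
    # what this process sees in the (shared) weight storage.
    ready.wait()
    with torch.no_grad():
        results.put((actor_id, actor_net.weight.sum().item()))


def run_actors(num_actors=2):
    ctx = mp.get_context("fork")  # assumption: fork context, monobeast-style
    actor_net = nn.Linear(4, 2)   # toy stand-in for the real actor net
    actor_net.share_memory()      # parameters now live in shared memory

    ready = ctx.Event()
    results = ctx.Queue()
    procs = [
        ctx.Process(target=act, args=(i, actor_net, ready, results))
        for i in range(num_actors)
    ]
    for p in procs:
        p.start()

    # In-place write from the parent *after* the children were started.
    with torch.no_grad():
        actor_net.weight.fill_(1.0)
    ready.set()

    # Drain the queue before joining to avoid blocking on the feeder thread.
    seen = dict(results.get() for _ in range(num_actors))
    for p in procs:
        p.join()
    return seen


if __name__ == "__main__":
    # Each actor sees the parent's write: 2 * 4 ones sum to 8.0.
    print(run_actors())
```

If the children had copy-on-write copies instead of shared storage, they would report the sum of the random initial weights rather than 8.0.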

I don’t understand what is really happening. I assume the actor’s weights exist in only one location in memory, but each worker process is mostly self-contained and has its own Python object representing the actor.
So multiple worker processes (and therefore multiple CPU cores) are able to do inference in parallel? When performing inference, are they all reading the model’s params from the exact same memory? Could this become a bottleneck as the number of workers grows? Also, the main process occasionally updates the actor’s params to match the learner’s with actor_model.load_state_dict(model.state_dict()), and it does so without any locking. How is this possible?
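One way to see why the unlocked sync works: by default, load_state_dict copies the incoming values into the existing parameter tensors in place, so the shared storage the actor processes read from is never replaced; they simply see new values on their next forward pass. A small check, with make_net as a hypothetical stand-in for the real architecture:

```python
import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for the actor/learner architecture.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


learner_model = make_net()
actor_model = make_net()
actor_model.share_memory()

# Where each shared parameter lives before the sync.
ptrs_before = [p.data_ptr() for p in actor_model.parameters()]

# Monobeast-style sync: copy learner weights into the shared actor net.
actor_model.load_state_dict(learner_model.state_dict())

ptrs_after = [p.data_ptr() for p in actor_model.parameters()]
print(ptrs_before == ptrs_after)                              # same storage
print(all(p.is_shared() for p in actor_model.parameters()))  # still shared
```

Because nothing synchronizes that copy, an actor that runs a forward pass while it is in progress could read a mix of old and new parameters. This sketch does not demonstrate that; presumably it is tolerated in practice as a little extra staleness on top of the usual lag between actor and learner weights.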
Within each of my worker processes, I have multiple threads that use the actor net. I feel safe knowing that these threads cannot access the net at the same time because of the GIL. However, that is not true in the multiprocessing case, right?
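On the threads inside one worker: as far as I understand, PyTorch releases the GIL inside most tensor operations, so the GIL does not actually stop two threads from running the net at the same time. Concurrent forward passes under torch.no_grad() only read the parameters, which is generally safe; a lock would matter only if some thread in the same process also writes the weights. A small sketch with a toy net and four threads (the names are illustrative):

```python
import threading

import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for the actor net; share_memory() mirrors the real setup.
actor_net = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 4))
actor_net.share_memory()
inputs = [torch.randn(32, 8) for _ in range(4)]

# Reference outputs computed one after another in the main thread.
with torch.no_grad():
    expected = [actor_net(x) for x in inputs]

results = [None] * len(inputs)


def infer(i):
    # Grad mode is thread-local, so each thread enters no_grad itself.
    # The forward pass only reads the shared parameters.
    with torch.no_grad():
        results[i] = actor_net(inputs[i])


threads = [threading.Thread(target=infer, args=(i,)) for i in range(len(inputs))]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(all(torch.allclose(r, e) for r, e in zip(results, expected)))
```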

Secondly, I tried torch.jit.trace on my custom net, which uses some Python iteration, and got a 33% speedup. I would like to use a compiled actor net. Would it be as simple as compiling, calling share_memory() on the compiled net, and using that instead?
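A sketch of the compiled-actor idea, assuming the Python iteration in the real net has a fixed trip count (ActorNet here is hypothetical). torch.jit.trace records the operations executed for the example input, so a Python loop is unrolled to however many iterations ran during tracing; if the count depends on the input, the traced net can silently diverge from the eager one, and torch.jit.script would be the safer choice. As far as I know, share_memory() is one of the nn.Module methods that stays available on traced modules, so it is called the same way:

```python
import torch
import torch.nn as nn


class ActorNet(nn.Module):
    # Hypothetical net with a Python loop; the loop count is fixed,
    # so tracing unrolls it faithfully.
    def __init__(self, n_steps=3):
        super().__init__()
        self.n_steps = n_steps
        self.fc = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        h = x
        for _ in range(self.n_steps):
            h = torch.tanh(self.fc(h))
        return self.head(h)


torch.manual_seed(0)
net = ActorNet().eval()
example = torch.randn(2, 8)

traced = torch.jit.trace(net, example)
traced.share_memory()  # same call as for the eager module

# A different batch size than the tracing example still works here,
# because only Linear/tanh ops were recorded.
x = torch.randn(5, 8)
with torch.no_grad():
    same = torch.allclose(traced(x), net(x))
print(same)
print(all(p.is_shared() for p in traced.parameters()))
```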

Edit: calling load_state_dict() on a compiled net and passing one to the actors (with torch.mp, not mp) doesn’t throw errors. It remains to be seen whether it works identically.
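To check the "works identically" part in a single process, one can push a simulated learner update into a shared, traced actor and compare its outputs against the eager learner. The ActorNet class and the +0.1 weight nudge are hypothetical; the check also confirms that load_state_dict wrote into the existing shared storage rather than replacing it:

```python
import torch
import torch.nn as nn


class ActorNet(nn.Module):
    # Hypothetical architecture shared by learner and actor.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        h = x
        for _ in range(3):  # fixed-length Python loop, unrolled by tracing
            h = torch.tanh(self.fc(h))
        return self.head(h)


torch.manual_seed(0)
learner_model = ActorNet()
actor_traced = torch.jit.trace(ActorNet(), torch.randn(1, 8))
actor_traced.share_memory()
ptrs_before = [p.data_ptr() for p in actor_traced.parameters()]

# Simulated learner update, then a monobeast-style sync.
with torch.no_grad():
    for p in learner_model.parameters():
        p.add_(0.1)
actor_traced.load_state_dict(learner_model.state_dict())

x = torch.randn(5, 8)
with torch.no_grad():
    outputs_match = torch.allclose(actor_traced(x), learner_model(x))
storage_unchanged = ptrs_before == [p.data_ptr() for p in actor_traced.parameters()]
print(outputs_match, storage_unchanged)
```

This only covers the single-process case; the multi-process behavior should then follow from the shared storage staying in place.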