Getting stuck while loading a big tensor in a subprocess

Environment

  • OS: Linux Mint 21 (based on Ubuntu 22.04)
  • Python: 3.9.15
  • PyTorch: 2.0.0

Description

I’m currently working on an asynchronous reinforcement learning algorithm, and I want to create multiple processes in which agents interact with the environment and collect trajectories into a replay buffer.
In each subprocess, the agent periodically loads the learned, shared actor parameters to collect data.
However, the subprocess hangs when the agent loads a shared actor that contains a large tensor.
How could I solve this issue?

Below is a minimal example that reproduces the problem:

import os

import torch.multiprocessing as mp
import torch.nn as nn


class FakeAgent:
    def __init__(self):
        self._actor = nn.Sequential(
            nn.Linear(512, 512, bias=False)
        )

def collect_traj(shared_actor):
    pid = os.getpid()
    agent_behav = FakeAgent()

    # [START] Repeat for some fixed number of iterations
    print(f"pid {pid}: Start loading shared model ...")
    agent_behav._actor.load_state_dict(shared_actor.state_dict())
    print(f"pid {pid}: Model loaded.")
    
    # collect trajectories here using agent_behav
    # ...
    
    # [END]   Repeat for some fixed number of iterations


if __name__ == '__main__':
    agent = FakeAgent()
    agent._actor.share_memory()
    
    ppool = []
    for _ in range(2):
        p = mp.Process(target=collect_traj, args=(agent._actor,))
        ppool.append(p)

    for p in ppool:
        p.start()

    for p in ppool:
        p.join()

    print("All processes done.")
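A commonly reported cause of this kind of hang is calling fork() after PyTorch's intra-op (OpenMP) thread pool has already been used in the parent: the child inherits the pool in an unusable state, and the first parallel operation on a large enough tensor (here, the copy inside load_state_dict) blocks, while small tensors are copied serially. Assuming that is what happens here, the sketch below shows two commonly suggested workarounds: starting workers with the "spawn" start method, and limiting each worker to one intra-op thread with torch.set_num_threads(1) (the same call torch.utils.data DataLoader workers make). The num_iters parameter and the main() helper are hypothetical additions for illustration, not part of the original code.

```python
import os

import torch
import torch.multiprocessing as mp
import torch.nn as nn


class FakeAgent:
    def __init__(self):
        self._actor = nn.Sequential(nn.Linear(512, 512, bias=False))


def collect_traj(shared_actor, num_iters=3):
    # Workaround 1 (assumption: the hang comes from an inherited OpenMP pool):
    # with a single intra-op thread, large copies run serially in this worker.
    torch.set_num_threads(1)
    pid = os.getpid()
    agent_behav = FakeAgent()
    # num_iters is a hypothetical stand-in for "some fixed number of iterations".
    for _ in range(num_iters):
        agent_behav._actor.load_state_dict(shared_actor.state_dict())
        print(f"pid {pid}: Model loaded.")
        # collect trajectories here using agent_behav
    return agent_behav


def main(num_workers=2):
    # Workaround 2: "spawn" starts each worker in a fresh interpreter instead of
    # forking the parent, so no thread-pool state is inherited.
    ctx = mp.get_context("spawn")
    agent = FakeAgent()
    agent._actor.share_memory()

    procs = [
        ctx.Process(target=collect_traj, args=(agent._actor,))
        for _ in range(num_workers)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Exit code 0 means the worker finished normally.
    return [p.exitcode for p in procs]


if __name__ == "__main__":
    print(main())
    print("All processes done.")
```

The two measures are independent: with the default fork start method, the set_num_threads(1) call alone may be enough, while spawn alone avoids inheriting the parent's thread pool entirely, at the cost of slower worker start-up because each worker re-imports torch.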