PyTorch multi-processing while syncing model weights

I am training an RL model, which consists of two steps:

  • Collecting data by letting the RL agent interact with the simulation environment
  • Training the RL agent (policy network) on the collected data

During the entire training process, we repeat these two steps for multiple rounds. The issue is that the data collection part involves some post-processing (using a big pre-trained segmentor to generate object masks on the collected images), which is very slow.

Currently, I'm thinking of having one GPU train the RL agent (GPU1) and another GPU dedicated to the data collection + segmentation part (GPU2). The only communication between the two GPUs is syncing the model weights, because we always want to collect data with up-to-date weights (i.e. update the GPU2 weights every 100 training steps on GPU1). We don't need to send the collected data to GPU1, because we plan to write it to disk, and the dataloader on GPU1 should be able to read it from there (i.e. read all the files present in a folder, regardless of when they were generated; see the sketch below).
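For the disk hand-off, something like the sketch below is what I have in mind (the .pt file format, folder name, and loader settings are just my assumptions): a dataset that re-scans the folder each time it is rebuilt, so episodes written later get picked up in the next training round.

```python
import glob
import torch
from torch.utils.data import Dataset, DataLoader

class FolderDataset(Dataset):
    """Reads every .pt file currently in the folder, regardless of when it was written."""
    def __init__(self, folder):
        # Re-scan on construction; rebuild the dataset each round to pick up new files.
        self.files = sorted(glob.glob(f"{folder}/*.pt"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Each file is assumed to hold one transition/episode saved with torch.save.
        return torch.load(self.files[idx], map_location="cpu")

# Rebuild the loader at the start of each training round so newly collected data is included.
loader = DataLoader(FolderDataset("collected_data"), batch_size=32, shuffle=True, num_workers=4)
```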

At first, I was thinking of using PyTorch DDP, as it handles all the low-level communication and keeps the model weights in sync across GPUs. However, I later realized that DDP keeps the replicas in sync by all-reducing gradients during the backward pass, not by broadcasting weights. In my case, only GPU1 computes a loss/gradient, while GPU2 is doing completely different operations, so GPU1 would block waiting for a collective that GPU2 never joins, i.e. a deadlock.

Now I think I should probably look into the multi-processing code to manually sync the model weights somehow. Are there any resources I should look at? (I'm currently looking at torch.multiprocessing.) Since the only thing I want to sync is the model weights, I guess it won't be very complex.
(Worst case, I can also save the model weights to disk and load them on GPU2 whenever there is a new checkpoint; then there is no direct communication between the GPUs at all.)
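If I end up going the file-based route, I assume it would look roughly like this (the checkpoint path and the write-to-temp-then-rename trick are my own choices, just to avoid GPU2 ever reading a half-written file):

```python
import os
import torch

CKPT = "latest_policy.pt"

# GPU1 (training) side, called every 100 steps:
def publish_weights(policy):
    # Write to a temp file first, then rename, so GPU2 never sees a partial checkpoint.
    tmp = CKPT + ".tmp"
    torch.save({k: v.cpu() for k, v in policy.state_dict().items()}, tmp)
    os.replace(tmp, CKPT)  # atomic rename on the same filesystem

# GPU2 (collection) side, called before each episode:
def maybe_reload_weights(policy, last_mtime):
    # Reload only if the checkpoint file has changed since the last check.
    if os.path.exists(CKPT):
        mtime = os.path.getmtime(CKPT)
        if mtime != last_mtime:
            state = torch.load(CKPT, map_location="cuda:1")
            policy.load_state_dict(state)
            return mtime
    return last_mtime
```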

OK, we finally decided to use torch.multiprocessing to sync the model weights. It turns out to be quite fast.
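In case it helps anyone, below is a simplified sketch of how a queue-based weight sync with torch.multiprocessing can look (the nn.Linear stands in for the real policy network, and the random batch stands in for data read from disk): the trainer pushes a CPU copy of the state_dict every 100 steps, and the collector drains the queue without blocking and loads whatever is newest.

```python
import queue
import torch
import torch.nn as nn
import torch.multiprocessing as mp

SYNC_EVERY = 100

def trainer(weight_queue):
    policy = nn.Linear(8, 2).to("cuda:0")            # stand-in for the real policy network
    opt = torch.optim.Adam(policy.parameters())
    for step in range(1, 1001):
        x = torch.randn(32, 8, device="cuda:0")      # stand-in for a batch read from disk
        loss = policy(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % SYNC_EVERY == 0:
            # Push a CPU copy so GPU1 tensors are not shared across processes directly.
            cpu_state = {k: v.detach().cpu() for k, v in policy.state_dict().items()}
            weight_queue.put(cpu_state)
    weight_queue.put(None)                            # sentinel: training finished

def collector(weight_queue):
    policy = nn.Linear(8, 2).to("cuda:1")             # same architecture, lives on GPU2
    done = False
    while not done:
        # Drain the queue without blocking; keep only the newest weights.
        latest = None
        while True:
            try:
                msg = weight_queue.get_nowait()
            except queue.Empty:
                break
            if msg is None:
                done = True
            else:
                latest = msg
        if latest is not None:
            policy.load_state_dict(latest)            # cross-device copy handled by PyTorch
        # ... run the simulator + segmentor here and write the results to disk ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")                     # spawn is needed when child processes use CUDA
    q = ctx.Queue()
    p1 = ctx.Process(target=trainer, args=(q,))
    p2 = ctx.Process(target=collector, args=(q,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
```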