Lewis_Liu
(Lewis Liu)
July 17, 2020, 8:17pm
1
I’m trying to share the weights of networks between processes using a multiprocessing manager.dict.
The code is as follows:
for name in critic_state_dict:
    self.shared_networks.critic[name] = T.tensor(critic_state_dict[name].clone().cpu().detach().numpy())
This works fine on Windows, but when I run it on a cluster, it hangs in the middle of the for loop.
How do I fix this? Or if I want to periodically share the weights among processes, how to do it properly?
Thanks
mrshenli
(Shen Li)
July 18, 2020, 3:49pm
2
Hey @Lewis_Liu ,
Did you use fork or spawn?
"Or if I want to periodically share the weights among processes, how to do it properly?"
One solution is to create a multiprocessing queue and pass that queue to the child processes. Then, in the loop, use that queue to pass shared tensors. The test below can serve as an example:
@classmethod
def _test_allgather_process(
        cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
    pg = init_pg(rank, filename, world_size)
    xs = [shared_tensors[rank]]
    ys = [[torch.zeros_like(xs[0]) for i in range(world_size)]]
    pg.allgather(ys, xs).wait()
    for i in range(world_size):
        c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))
    p2c.get()

@unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
def test_shared_allgather_gloo(self):
    self._test_multiprocess(
        ProcessGroupShareTensorTest._test_allgather_process,
        [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
        ProcessGroupShareTensorTest._init_pg_gloo,
        self.world_size)
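Outside the test harness, the same queue idea might look roughly like the sketch below. It is only an illustration under several assumptions: the network is stood in for by a hypothetical `nn.Linear(4, 2)`, the helper names `cpu_state_dict` and `reader` are made up for this example, and a spawn context is requested explicitly so the behavior is the same on Windows and Linux.

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp


def cpu_state_dict(model):
    """Return a detached CPU copy of every tensor in the model's state_dict."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


def reader(weights_q, done_q):
    # Hypothetical consumer: build the same architecture, then load the weights.
    model = nn.Linear(4, 2)
    weights = weights_q.get()  # blocks until the writer sends a snapshot
    model.load_state_dict(weights)
    done_q.put(float(model.weight.sum()))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # assumption: use spawn on every platform
    weights_q, done_q = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=reader, args=(weights_q, done_q))
    proc.start()

    model = nn.Linear(4, 2)
    snapshot = cpu_state_dict(model)  # keep a reference while it is in flight
    weights_q.put(snapshot)
    print("reader saw weight sum:", done_q.get(timeout=60))
    proc.join()
```

The writer keeps `snapshot` alive and stays running until the reader has reported back, since the receiving process needs the sender to still exist when it picks up the shared tensors.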
Lewis_Liu
(Lewis Liu)
July 18, 2020, 4:06pm
3
Hi Li,
I switched to using a Queue. But I still can’t avoid first getting the network tensors from the state_dict, right?
I believe it’s spawn on my Windows workstation and fork on the Linux cluster, if I’m correct.
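One way to confirm which start method is in effect on each machine, and to request the same one on both, is sketched below using only the standard library. The helper name `describe_start_method` is hypothetical; the default it reports depends on the platform and Python version.

```python
import multiprocessing as mp


def describe_start_method():
    """Return the start method this interpreter uses by default."""
    return mp.get_start_method(allow_none=False)


if __name__ == "__main__":
    print("available:", mp.get_all_start_methods())
    print("default:", describe_start_method())

    # Requesting a context pins the start method for objects created from it,
    # without changing the global default for the rest of the program.
    ctx = mp.get_context("spawn")
    print("context start method:", ctx.get_start_method())
```

Creating queues and processes from the same `ctx` keeps the Windows workstation and the cluster on one code path, which makes a platform-specific hang easier to isolate.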
mrshenli
(Shen Li)
July 18, 2020, 4:09pm
4
I don’t have the full context here. Can you let the process holding the state_dict be the writer to the queue?
Lewis_Liu
(Lewis Liu)
July 18, 2020, 4:23pm
5
Yep.
The network is trained and updated for one step. After that, the process has a single task: writing the state_dict into the queue. The other processes don’t have direct access to the network except through the queue.
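The writer/reader split described here could be sketched as follows. This is an assumption-laden illustration, not the poster's code: `publish_weights` and `load_latest_weights` are hypothetical helpers, the network is a placeholder `nn.Linear`, and the demo uses an in-process `queue.Queue` as a stand-in for a multiprocessing queue. It also assumes one queue per reader, because each item on a queue is delivered to only one consumer.

```python
import queue as pyqueue

import torch
import torch.nn as nn


def publish_weights(q, model):
    """Writer side: put a detached CPU copy of the state_dict on the queue."""
    snapshot = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    q.put(snapshot)


def load_latest_weights(q, model):
    """Reader side: drain the queue and load only the newest snapshot, if any.

    Returns True when the model was updated, False when the queue was empty.
    """
    latest = None
    while True:
        try:
            latest = q.get_nowait()
        except pyqueue.Empty:  # multiprocessing queues raise the same exception
            break
    if latest is not None:
        model.load_state_dict(latest)
    return latest is not None


if __name__ == "__main__":
    q = pyqueue.Queue()  # in-process stand-in for a multiprocessing queue
    trainer, worker = nn.Linear(4, 2), nn.Linear(4, 2)

    publish_weights(q, trainer)
    with torch.no_grad():
        trainer.weight.add_(1.0)  # pretend one training step happened
    publish_weights(q, trainer)

    updated = load_latest_weights(q, worker)
    print("updated:", updated, "in sync:", torch.equal(worker.weight, trainer.weight))
```

Draining the queue before loading means a slow reader skips stale snapshots instead of replaying every intermediate update.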