Optimizer.step() hangs on Linux; multiprocessing

Background:
I’m doing distributed PPO (basically gathering data from several workers and training on one learner).

Issue:
Data collection works fine, but when I train the network with the lines below

            self.critic.optimizer.zero_grad()
            batch_states_values = self.critic.forward(batch_states)
            print('critic batch_states_values done')
            critic_loss = F.mse_loss(batch_states_values, batch_REFs)
            print('critic critic_loss done')
            critic_loss.backward()
            print('critic loss backward done')
            self.critic.optimizer.step()
            print('critic step done')

And the output shows:

critic batch_states_values done
critic critic_loss done
critic loss backward done

So it appears that the program hangs after the backward pass. What could be the cause?
It works fine on my Windows workstation but hangs when I run it on a Linux machine.

Hey @Lewis_Liu, which part of the program is distributed? Since torch.distributed does not support Windows yet, I assume the working version of the program on Windows does not use distributed training?

BTW, which optimizer are you using?

The training isn’t distributed and torch.distributed isn’t used.

By distributed I mean that the workers used to collect data are distributed, and the network params are sent from the trainer to these workers through an mp.Queue.

Once the data are collected and the trainer starts to train, the workers stop working, so I suppose there is no interaction between the workers and the trainer. What seems really strange to me is that the backward finishes but the step does not.
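
Roughly, the structure looks like this (a minimal, simplified sketch with stand-in names and dummy data, not the actual code):

    import torch
    import torch.multiprocessing as mp

    def worker(param_queue, data_queue):
        # Each worker receives the latest network params from the trainer,
        # collects a rollout with them, and sends the data back.
        state_dict = param_queue.get()
        rollout = {"states": [[0.0] * 4 for _ in range(8)]}  # stand-in for real rollout data
        data_queue.put(rollout)

    if __name__ == "__main__":
        mp.set_start_method("spawn", force=True)
        param_queue, data_queue = mp.Queue(), mp.Queue()
        workers = [mp.Process(target=worker, args=(param_queue, data_queue)) for _ in range(2)]
        for p in workers:
            p.start()
        # Trainer side: broadcast params, gather rollouts, then train on one learner.
        for _ in workers:
            param_queue.put({"w": torch.zeros(4)})
        batches = [data_queue.get() for _ in workers]
        for p in workers:
            p.join()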

I’m using the standard optim.Adam

Are you using any CUDA ops? If so, could you please add a torch.cuda.synchronize() before every print to make sure that preceding ops are indeed done instead of still pending in the CUDA stream?
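
For example, the critic update above could be instrumented roughly like this (a sketch; torch.cuda.synchronize() blocks until all kernels queued on the GPU so far have finished, so the last print really tells you where the hang is):

    self.critic.optimizer.zero_grad()
    batch_states_values = self.critic.forward(batch_states)
    torch.cuda.synchronize()  # wait for the forward pass to finish on the GPU
    print('critic batch_states_values done')
    critic_loss = F.mse_loss(batch_states_values, batch_REFs)
    torch.cuda.synchronize()
    print('critic critic_loss done')
    critic_loss.backward()
    torch.cuda.synchronize()  # wait for the backward pass to finish on the GPU
    print('critic loss backward done')
    self.critic.optimizer.step()
    torch.cuda.synchronize()
    print('critic step done')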

Hi Li,

I just added the line and the prints are the same.

FYI, after it hangs there, I killed the program and it showed the traceback below. I’m not sure if this is helpful:

Traceback (most recent call last):
  File "test.py", line 23, in <module>
Process Process-2:
    p.join()
  File "/apps/software/Python/3.7.4-GCCcore-8.3.0/lib/python3.7/multiprocessing/process.py", line 140, in join
Process Process-1:
    res = self._popen.wait(timeout)
  File "/apps/software/Python/3.7.4-GCCcore-8.3.0/lib/python3.7/multiprocessing/popen_fork.py", line 48, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/apps/software/Python/3.7.4-GCCcore-8.3.0/lib/python3.7/multiprocessing/popen_fork.py", line 28, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt

That’s weird. Is there a way we can reproduce this locally, so that we can help debug?

Hi Li,

Thanks for the help. I’m afraid it wouldn’t be easy. I’ll try to find a way to convert it into something that can be shared.

Meanwhile, what would you say might be the cause? Any chance this could be the use of mp.Queue or mp.Value in the Linux environment? If that’s likely, I can try to avoid them or change the way I use them.

Are you using torch.multiprocessing.SimpleQueue? If yes, when sharing CPU tensors, does the program guarantee that the owner of the shared data object is still alive when the consumer uses it? And are you using spawn to create processes?

Unlike with CPU tensors, when sharing CUDA tensors the sending process is required to keep the original tensor alive for as long as the receiving process retains a copy of it. The refcounting is implemented under the hood, but it requires users to follow the documented best practices.
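
As a rough illustration of that constraint (a standalone sketch, not code from the program above): when a CUDA tensor is sent through a torch.multiprocessing queue, the producer has to keep its own reference, and stay alive, until the consumer is done with it, e.g. by waiting on an event:

    import torch
    import torch.multiprocessing as mp

    def consumer(queue, done):
        t = queue.get()        # this still refers to the producer's CUDA memory
        print(t.sum().item())
        done.set()             # tell the producer it is now safe to drop the tensor

    if __name__ == "__main__":
        mp.set_start_method("spawn", force=True)  # required when CUDA is involved
        queue, done = mp.Queue(), mp.Event()
        p = mp.Process(target=consumer, args=(queue, done))
        p.start()
        t = torch.ones(4, device="cuda")
        queue.put(t)           # t must stay referenced here until `done` is set
        done.wait()
        p.join()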

I used spawn by adding the line

    mp.set_start_method("spawn", force=True)

I used torch.multiprocessing.Queue instead of SimpleQueue. The owners are always alive.

I use spawn on Windows but not on the cluster.

Interesting fact: on the Linux cluster, it works if I change the device to 'cpu' instead of using a GPU. On Windows, both devices work.

Have you solved this problem? I have run into the same problem when running with multiple machines.

Not completely solved, but I was able to find the cause and work around it. The issue was that the network was somehow shared with other processes. So my practical suggestion would be to check everything that might lead to your network being shared or accessed by another process, e.g. a mistake in how copy.copy or copy.deepcopy is used to send the state_dict.
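
For example, one way to avoid sharing the live network at all (a sketch with hypothetical names: self.network is the learner’s model and param_queue is the queue used to ship params to the workers) is to push a detached CPU copy of the state_dict instead of the module or its live GPU tensors:

    # Send a detached CPU copy of the weights, so no other process ever holds
    # a reference to the trainer's live (possibly CUDA) parameters.
    cpu_state = {k: v.detach().cpu().clone()
                 for k, v in self.network.state_dict().items()}
    param_queue.put(cpu_state)  # workers call load_state_dict() on this copy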