Best way to debug DistributedDataParallel code?

Hey,

I’m having an issue where my code randomly hangs when using DistributedDataParallel.
The hang is completely random, and it does not always occur.
I suspect it has something to do with DistributedDataParallel because, of the 4 GPUs I’m using, 3 report 100% utilization and 1 is completely idle.

What is the best way to debug what is going on, given that I’m getting no errors?


It looks like the program somehow got desynchronized (different processes want to sync different numbers of parameters, or run different numbers of iterations). Unfortunately, DDP does not have a debug mode for now. Can you share a code snippet? Does your code try to handle any errors in the backward pass on its own (say, catch OOM and retry)? Do all processes see exactly the same amount of input data?
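To make the “same amount of input data” question concrete, here is a small sketch assuming a hypothetical hand-rolled sharding scheme (`indices[rank::world_size]`) where each rank iterates over its own shard in batches of 32. The helper names are made up for illustration. With 1027 samples on 4 ranks, three ranks get one more batch than the fourth, so they enter an extra backward/allreduce that the fourth never joins, which would match the 3-busy, 1-idle symptom.

```python
import math

def strided_shard_sizes(num_samples, world_size):
    """Samples per rank under hypothetical strided sharding: indices[rank::world_size]."""
    return [len(range(rank, num_samples, world_size)) for rank in range(world_size)]

def batches_per_rank(shard_sizes, batch_size):
    """Training iterations each rank runs over its own shard (last partial batch kept)."""
    return [math.ceil(n / batch_size) for n in shard_sizes]

def padded_shard_size(num_samples, world_size):
    """Per-rank size if the dataset is padded so every rank gets the same count."""
    return math.ceil(num_samples / world_size)

sizes = strided_shard_sizes(1027, 4)
print(sizes)                        # [257, 257, 257, 256]
print(batches_per_rank(sizes, 32))  # [9, 9, 9, 8] -> rank 3 never joins the 9th allreduce
print(padded_shard_size(1027, 4))   # 257 for every rank
```

For reference, `torch.utils.data.distributed.DistributedSampler` pads by default (`drop_last=False`) so every rank gets `ceil(N / world_size)` samples, which avoids this particular mismatch; custom sharding is where it can creep in.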

I’ve had to deal with similar issues, you should feel lucky you only have 4 processes and not 64 :slight_smile:

The fact that 3 are at 100% utilization means they are inside an NCCL sync operation, like the one at the end of .backward(), while the 4th one is doing something non-GPU-related, like waiting for user input.
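A rough analogy for what the three busy ranks are doing, using `threading.Barrier` as a stand-in for the gradient allreduce. This is not how NCCL is implemented, only a model of the synchronization: every rank must arrive before anyone proceeds, so ranks that run one extra iteration block until the missing rank shows up. The sketch adds a timeout so it terminates; a real NCCL collective may instead wait indefinitely (or until the timeout configured in `init_process_group`), with the GPU kept busy while it waits.

```python
import threading
import time

def worker(rank, num_iters, allreduce, stuck):
    for step in range(num_iters):
        time.sleep(0.01)  # stand-in for forward/backward compute
        try:
            # Stand-in for the gradient allreduce at the end of backward():
            # no rank proceeds until every rank has arrived.
            allreduce.wait(timeout=1.0)
        except threading.BrokenBarrierError:
            # Timed out waiting for a rank that never arrived.
            stuck.append((rank, step))
            return

def run(iters_per_rank):
    """Run one thread per 'rank'; return the (rank, step) pairs that got stuck."""
    allreduce = threading.Barrier(len(iters_per_rank))
    stuck = []
    threads = [
        threading.Thread(target=worker, args=(rank, n, allreduce, stuck))
        for rank, n in enumerate(iters_per_rank)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(stuck)

print(run([9, 9, 9, 8]))  # [(0, 8), (1, 8), (2, 8)]: three ranks wait on the fourth
```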

The general strategy is to look at the stack trace of the “odd one out” process. You can get a C++ stack trace by running “gdb -p <pid>” and then “bt” or “thread apply all bt”.
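If you need to do this for several workers, a small wrapper can run gdb non-interactively. This is a sketch that assumes gdb is installed and that you are allowed to attach to the process (ptrace restrictions may require running as the same user or with sudo). `-batch` makes gdb exit, and detach from the process, after running the `-ex` command. The function names are hypothetical.

```python
import subprocess

def gdb_backtrace_cmd(pid, all_threads=True):
    """Build a non-interactive gdb command that attaches to `pid` and prints a backtrace."""
    command = "thread apply all bt" if all_threads else "bt"
    return ["gdb", "-p", str(pid), "-batch", "-ex", command]

def dump_native_backtrace(pid, all_threads=True):
    """Run gdb and return everything it printed (gdb detaches when -batch finishes)."""
    result = subprocess.run(
        gdb_backtrace_cmd(pid, all_threads), capture_output=True, text=True
    )
    return result.stdout + result.stderr
```

`nvidia-smi` lists the PID of each process using a GPU, so the process on the idle GPU is the one to pass in, e.g. `print(dump_native_backtrace(12345))` with its actual PID.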

With a bit more work, you can get the Python stack. This requires modifying client code. For instance, I run install_pdb_hander on init of my processes. It allows me to break into PDB on CTRL+\ and look at what the current process is doing.
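A minimal sketch of what such a helper might look like, assuming only the standard `signal` and `pdb` modules. The function below is hypothetical, not the original `install_pdb_hander`. CTRL+\ sends SIGQUIT, and the handler opens pdb at whatever Python frame was executing when the signal arrived.

```python
import pdb
import signal

def install_pdb_handler(sig=signal.SIGQUIT):
    """Hypothetical helper: break into pdb at the interrupted frame when `sig` arrives."""
    def handler(signum, frame):
        # `frame` is the Python frame that was running when the signal was handled.
        pdb.Pdb().set_trace(frame)

    # CTRL+\ in a terminal sends SIGQUIT; this replaces the default core-dump behavior.
    signal.signal(sig, handler)
    return handler
```

One caveat: Python runs signal handlers in the main thread and only between bytecode instructions, so if the process is stuck inside a long C/CUDA call, the pdb prompt will not appear until that call returns. gdb is the fallback in that case.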

When using the distributed launcher, CTRL+\ is only sent to the launcher process, so to get the stack trace of an arbitrary worker you may need to modify your launch procedure to run the 4 workers in 4 different tmux windows.
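An alternative that avoids tmux is to have every worker dump its Python stack to its own file when it receives a signal you send explicitly. This sketch uses the standard-library `faulthandler.register` (POSIX only). The `RANK` environment variable (exported by launchers such as torchrun), the `/tmp/traceback_rank{rank}.log` path, and the helper name are assumptions.

```python
import faulthandler
import os
import signal

def install_traceback_dump(log_dir="/tmp", sig=signal.SIGUSR1):
    """Dump every thread's Python stack to a per-rank file whenever `sig` arrives."""
    # Assumption: the launcher exports RANK; fall back to "0" otherwise.
    rank = os.environ.get("RANK", "0")
    path = os.path.join(log_dir, f"traceback_rank{rank}.log")
    # Keep the file open for the life of the process: faulthandler writes to its fd.
    log_file = open(path, "w")
    faulthandler.register(sig, file=log_file, all_threads=True)
    return log_file
```

Call it once at worker startup, then run `kill -USR1 <pid>` on the odd worker and read its log file. Because faulthandler dumps the traceback from a C-level signal handler, this should work even when the main thread is blocked in native code, unlike the pdb approach.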


Related issue: https://github.com/pytorch/pytorch/issues/27757