Problem:
My DDP program always hangs during the backward pass for some specific inputs.
Background:
Simply put, I have a CNN network called ‘net_cnn’ and an MLP network called ‘net_mlp’.
The inputs are first fed into ‘net_cnn’, producing outputs called ‘out_cnn’. ‘out_cnn’ is then fed into ‘net_mlp’, producing outputs called ‘out_mlp’.
‘loss1’ is computed using ‘out_cnn’. ‘loss2’ is computed using ‘out_mlp’.
A sketch of the pipeline is shown below.
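Roughly, the setup looks like this (the module definitions, shapes, and losses are simplified placeholders for illustration, not my actual code):

```python
import torch
import torch.nn as nn

class NetCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # (B, 3, H, W) -> (B, 16)
        return self.pool(self.conv(x)).flatten(1)

class NetMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        # (B, 16) -> (B, 4)
        return self.fc(x)

net_cnn = NetCNN()
net_mlp = NetMLP()

inputs = torch.randn(8, 3, 32, 32)
out_cnn = net_cnn(inputs)        # used for loss1
out_mlp = net_mlp(out_cnn)       # used for loss2
loss1 = out_cnn.pow(2).mean()    # placeholder loss on out_cnn
loss2 = out_mlp.pow(2).mean()    # placeholder loss on out_mlp
(loss1 + loss2).backward()
```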
The question:
For some specific inputs, ‘out_mlp’ may be all zeros, with no dependence on ‘out_cnn’ or on the parameters of ‘net_mlp’. As a result, ‘loss2’ has no gradients with respect to the parameters of ‘net_cnn’ and ‘net_mlp’. When using DDP with multiple GPUs, if one GPU receives such inputs while the others do not, the other GPUs wait for the missing gradients during backward, and the program hangs.
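Continuing the sketch from the Background section, the problematic case looks roughly like this (the condition is a made-up stand-in for the real data-dependent check):

```python
import torch

def run_mlp(out_cnn, inputs):
    # hypothetical data-dependent condition; stands in for the real check
    if inputs.abs().sum() == 0:
        # all-zero output created from scratch: it is not connected to the
        # autograd graph, so loss2 produces no gradients for net_cnn / net_mlp
        return torch.zeros(out_cnn.size(0), 4, device=out_cnn.device)
    return net_mlp(out_cnn)
```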
Is there any solution to this kind of problem? Thanks for any suggestions.