Running nn.DataParallel model with occasionally missing losses

Hi all,

I am trying to implement a model that trains on multiple GPUs with nn.DataParallel. The problem is that, due to the nature of my model, there are occasionally forward passes in which some losses are not produced.

Problem:
Without DataParallel, I simply set these missing losses to None or 0 and only add the losses that are not None or 0 to my total loss. However, once I move to multiple GPUs, there are cases where the replica on GPU 0 produces an actual loss while the replica on GPU 1 does not produce that particular loss.
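For reference, this is roughly how I handle it on a single GPU (the dict keys and values below are made up purely for illustration):

```python
import torch

# What the model might return on a pass where the "aux" branch did not fire
# (keys and values are placeholders for illustration):
loss_dict = {"cls": torch.tensor(0.7, requires_grad=True), "aux": None}

total_loss = None
for loss in loss_dict.values():
    # Skip losses that were not produced this pass (set to None or 0).
    if loss is None or (not torch.is_tensor(loss) and loss == 0):
        continue
    total_loss = loss if total_loss is None else total_loss + loss

if total_loss is not None:
    total_loss.backward()
```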

Hence, when the DataParallel module gathers these losses, bad things happen: it cannot combine the losses from the two GPUs.

Things I tried:
I tried setting the missing loss to 0, to a float tensor of 0, and to None. None of these work; each produces an error in torch/nn/parallel/scatter_gather.py. For instance, setting the missing loss to 0, which is not iterable, gives a "TypeError: zip argument #1 must support iteration" error.
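To make the failure concrete, here is a stripped-down sketch of the kind of forward that triggers it (ToyModel and the has_aux_target flag are made up; my real model is different):

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Made-up module whose auxiliary loss only exists for some samples."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x, has_aux_target):
        losses = {"main": self.fc(x).mean()}
        if (has_aux_target > 0).any():
            # Auxiliary branch only fires when the chunk has such samples.
            losses["aux"] = x.relu().mean()
        else:
            # Missing loss: 0 / None / a 0 tensor all break the gather step.
            losses["aux"] = 0
        return losses

model = nn.DataParallel(ToyModel().cuda())
x = torch.rand(16, 8).cuda()
# Only the second half of the batch has the auxiliary target, so after the
# scatter the replica on GPU 0 returns 0 for "aux" while the one on GPU 1
# returns a tensor.
has_aux_target = torch.cat([torch.zeros(8), torch.ones(8)]).cuda()
out = model(x, has_aux_target)  # with 2 GPUs this fails in scatter_gather.py
```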

Question:
Is there something in DataParallel that can ignore missing entries when combining losses from multiple GPUs? Or is there a possible solution to this problem that does not require altering the PyTorch code?
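One direction I have been considering (I am not sure whether it is sound) is to make every replica return tensors of the same shape on its own device, using a zero placeholder plus a count, and then mask the placeholders out after the gather. Roughly, continuing the made-up toy example from above:

```python
import torch
import torch.nn as nn

class ToyModelV2(nn.Module):
    """Same made-up toy, but every replica returns tensors of the same shape."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x, has_aux_target):
        main = self.fc(x).mean().unsqueeze(0)        # shape (1,) so gather can concat
        if (has_aux_target > 0).any():
            aux = x.relu().mean().unsqueeze(0)
            aux_count = torch.ones(1, device=x.device)
        else:
            # Placeholder with the same shape/device; masked out after the gather.
            aux = torch.zeros(1, device=x.device)
            aux_count = torch.zeros(1, device=x.device)
        return main, aux, aux_count

model = nn.DataParallel(ToyModelV2().cuda())
x = torch.rand(16, 8).cuda()
has_aux_target = torch.cat([torch.zeros(8), torch.ones(8)]).cuda()

main, aux, aux_count = model(x, has_aux_target)      # each has shape (num_gpus,)
total_loss = main.mean()
if aux_count.sum() > 0:
    # Average the aux loss only over the replicas that actually produced it.
    total_loss = total_loss + aux.sum() / aux_count.sum()
total_loss.backward()
```

Would something like this be the recommended way, or is there a cleaner built-in mechanism?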

Any help is greatly appreciated!