Logging loss value in DDP training

Hello,

I am reviewing the PyTorch ImageNet example in the repo, and I have trouble understanding the loss value returned by the criterion module. In Line 291, is the loss that is recorded later only for one process? Would summing and averaging the losses across all processes with ReduceOp.SUM be a better alternative? For example, when I save my model or simply log the metric, I would like to base it on the average loss across all processes.

In other words, does losses.update(loss.item(), images.size(0)) only record the loss value of one process?


Hey @amirhf, if you are using DistributedDataParallel, then yes, this is the local loss within one process, and different processes can have different loss values.

Is summing and averaging all losses across all processes using ReduceOp.SUM a better alternative?

This will give you the global loss, but it also introduces one more communication per iteration. If this is just for logging, would it be sufficient to log every n iterations, so that the amortized communication overhead is smaller?
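A minimal sketch of computing the global loss only every n iterations, assuming the loss is the per-sample mean over the local batch (the default reduction of nn.CrossEntropyLoss) and that every rank calls the function at the same iteration, since all_reduce is a collective. The names global_mean_loss and maybe_log and the log_every default of 200 are hypothetical. Summing loss × batch size together with the sample count, then dividing, gives the correct mean even when ranks see different batch sizes (e.g. the last batch of an epoch).

```python
import torch
import torch.distributed as dist


def _dist_ready() -> bool:
    return dist.is_available() and dist.is_initialized()


def global_mean_loss(loss: torch.Tensor, batch_size: int) -> float:
    """Batch-size-weighted mean of the loss over all DDP ranks.

    `loss` is assumed to be this rank's mean loss over its local batch.
    Every rank must call this at the same step (all_reduce is a collective).
    """
    # Pack [sum of per-sample losses, number of samples] into one tensor so
    # a single all_reduce carries both values. The tensor lives on the loss's
    # device, which the NCCL backend requires (CUDA tensors).
    stats = torch.tensor(
        [loss.item() * batch_size, float(batch_size)],
        dtype=torch.float64,
        device=loss.device,
    )
    if _dist_ready():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # in-place sum over ranks
    total, count = stats.tolist()
    return total / count


def maybe_log(iteration: int, loss: torch.Tensor, batch_size: int,
              log_every: int = 200):
    """Pay the communication cost only on logging iterations."""
    if iteration % log_every != 0:
        return None  # no collective on non-logging steps
    value = global_mean_loss(loss, batch_size)
    rank = dist.get_rank() if _dist_ready() else 0
    if rank == 0:
        print(f"iter {iteration}: global loss {value:.4f}")
    return value
```

If only rank 0 needs the number, dist.reduce(stats, dst=0) would also work; all_reduce is used here so that every rank holds the same value, which matters if, say, a checkpointing decision is based on it.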


Hi @mrshenli,

Thank you for your response! So the concern is that the reduce operation adds communication overhead. In that case I would just log the global loss every few iterations, something along the lines of if iteration % 200 == 0: reduce and log. Would that be okay?


Hi @amirhf,

I synchronize the loss record once per epoch, which is about 100 to 1500 iterations depending on the dataset I use. I didn't see much performance degradation from that. In my experience, the speedup from DistributedDataParallel (compared to DataParallel) was larger than the extra GPU communication time.
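A sketch of the once-per-epoch approach: accumulate the loss locally every iteration and do a single all_reduce at the end of the epoch. LossMeter is a hypothetical helper loosely modelled on the AverageMeter used in the ImageNet example (update(val, n) with val being a per-sample mean), not that class's actual code. With the NCCL backend, pass the rank's CUDA device to global_avg.

```python
import torch
import torch.distributed as dist


class LossMeter:
    """Accumulates a loss sum and sample count locally; syncs on demand."""

    def __init__(self) -> None:
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        # val is the mean loss over n samples, as in
        # losses.update(loss.item(), images.size(0)); no communication here.
        self.sum += val * n
        self.count += n

    def global_avg(self, device: torch.device = torch.device("cpu")) -> float:
        """Mean over every sample seen by every rank, using one all_reduce.

        All ranks must call this at the same point (e.g. end of epoch).
        """
        stats = torch.tensor([self.sum, float(self.count)],
                             dtype=torch.float64, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total, count = stats.tolist()
        return total / count if count > 0 else float("nan")
```

In a training loop, each rank would call meter.update(loss.item(), images.size(0)) per iteration and meter.global_avg(torch.device("cuda", local_rank)) once after the epoch, where local_rank is whatever device index the script already uses.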

You can refer to my code here.


Yep, that should be infrequent enough.

You can also hide this communication delay by passing async_op=True when launching dist.reduce, and only waiting on the returned handle after the next forward pass. This lets the communication of dist.reduce overlap with the next forward pass.
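A minimal sketch of that overlap, assuming every rank issues the same collectives in the same order; start_loss_reduce and finish_loss_reduce are hypothetical names. Launching uses tensor ops instead of loss.item(), so starting the reduce does not block on the GPU, and only rank 0 (the dst of dist.reduce) ends up with the reduced result.

```python
import torch
import torch.distributed as dist


def start_loss_reduce(loss: torch.Tensor, batch_size: int):
    """Launch a non-blocking reduce of [loss_sum, count] to rank 0.

    Returns (stats, handle); handle is None when not running distributed.
    """
    # Stay on the loss's device and avoid .item() so nothing syncs here.
    stats = torch.stack([
        loss.detach().double() * batch_size,
        torch.tensor(float(batch_size), dtype=torch.float64, device=loss.device),
    ])
    handle = None
    if dist.is_available() and dist.is_initialized():
        handle = dist.reduce(stats, dst=0, op=dist.ReduceOp.SUM, async_op=True)
    return stats, handle


def finish_loss_reduce(stats: torch.Tensor, handle) -> float:
    """Wait for the reduce, then return the mean loss (meaningful on rank 0)."""
    if handle is not None:
        handle.wait()
    total, count = stats.tolist()  # copies to CPU
    return total / count
```

The intended ordering is: after step i's optimizer step, call pending = start_loss_reduce(loss, images.size(0)); run step i + 1's forward pass; then call finish_loss_reduce(*pending) and log the value on rank 0. Combined with the every-n-iterations idea, start_loss_reduce would only be called on logging steps.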
