DistributedDataParallel loss computation

Hi,
I am working with DDP and I have a question about how the loss is computed in each process. With DDP I can split a batch across several processes and thereby increase the effective batch size. For instance, if each process gets a batch of 128 and I use 2 processes, the effective batch size is 128*2 = 256. For simplicity, I will refer to the local batch (the batch seen by a single process, i.e., 128) and the global batch (the batch seen by the whole DDP job, i.e., 256).
My question is the following. When each process computes its loss, that loss is averaged over the local batch, not over the global batch, so the resulting gradients depend only on the local batch. When I call loss.backward(), DDP fires its hooks each time all the gradients in a bucket are ready and averages them across all processes. However, it is not clear to me whether DDP also re-scales the loss (i.e., divides the total loss by the global batch size), or whether that is something I need to take care of myself.
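For reference, this is roughly the setup I mean (a minimal sketch; the model, dataset, and hyperparameters are placeholders, launched with e.g. `torchrun --nproc_per_node=2 train.py`):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Placeholder model and data.
model = DDP(torch.nn.Linear(10, 1).cuda(rank), device_ids=[rank])
dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 1))

# Each process sees a *local* batch of 128; with 2 processes the
# *global* batch is effectively 256.
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=128, sampler=sampler)

criterion = torch.nn.MSELoss(reduction="mean")  # mean over the *local* batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for x, y in loader:
    optimizer.zero_grad()
    loss = criterion(model(x.cuda(rank)), y.cuda(rank))
    loss.backward()   # DDP averages the gradients across processes here
    optimizer.step()
```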

Thank you

I arrived at the conclusion that there is no need to re-scale the loss, since averaging the gradients produces the same effect as computing the loss over the global batch.

This can be shown by writing out the loss for each local batch. The local loss is sum(l) / local_batch_size, where l is the list of per-sample losses in the local batch. The gradient of this loss with respect to the i-th parameter w_i is therefore the gradient of sum(l) w.r.t. w_i divided by the local batch size (e.g., 128). When DDP averages the gradients across processes, it divides again by the number of processes, so the overall divisor becomes local_batch_size * n_proc (e.g., 128 * 2 = 256), which is exactly the global batch size. In other words, averaging the mean-reduced local gradients is equivalent to taking the gradient of the mean loss over the global batch.
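Here is a quick single-process sanity check of this argument, simulating the two workers by hand (the tiny linear model and random data are just placeholders):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)
x = torch.randn(256, 10)          # "global" batch of 256
y = torch.randn(256, 1)
criterion = torch.nn.MSELoss(reduction="mean")

# Gradient of the mean loss over the global batch.
model.zero_grad()
criterion(model(x), y).backward()
global_grad = model.weight.grad.clone()

# Gradients of the mean loss over each local batch of 128, then averaged,
# which is what DDP effectively does across 2 processes.
model.zero_grad()
criterion(model(x[:128]), y[:128]).backward()
grad_a = model.weight.grad.clone()

model.zero_grad()
criterion(model(x[128:]), y[128:]).backward()
grad_b = model.weight.grad.clone()

print(torch.allclose(global_grad, (grad_a + grad_b) / 2))  # True
```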