I am working with DDP and I have a question about how the loss is computed per process. With DDP I can distribute a batch among different processes, thereby increasing the effective batch size. For instance, if each process sees a batch of 128 and I use 2 processes, I end up with an effective batch size of 128*2 = 256. For simplicity, I will refer to the local batch (the batch seen by a single process, 128 here) and the global batch (the batch seen across the entire DDP job, 256 here).
My question is the following. When I compute the loss on each process, that loss is averaged over the local batch, not the global batch, so the resulting gradients depend only on the local batch. When I call loss.backward(), DDP fires hooks each time all gradients in a bucket are ready and averages them across all processes. However, it is not clear to me whether DDP re-adjusts the loss (i.e., divides it by the global batch size) or whether that is something I need to take care of myself.
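To make the setup concrete, here is a small simulation of the two computations I am comparing (plain numpy, not actual DDP; the toy model and names are just for illustration). Each "process" averages its loss over its local batch of 128; the per-process gradients are then averaged, which is what I understand DDP's bucket hooks to do:

```python
import numpy as np

# Toy scalar linear model with squared-error loss averaged over the batch.
# The gradient of the batch-MEAN loss w.r.t. w is mean(2 * (w*x - y) * x).
rng = np.random.default_rng(0)
x = rng.normal(size=256)   # the "global batch" of 256 samples
y = rng.normal(size=256)
w = 0.5

def grad_mean_loss(xb, yb, w):
    """Gradient of the batch-mean squared error w.r.t. w."""
    return np.mean(2.0 * (w * xb - yb) * xb)

# Single-process view: loss (and gradient) averaged over the global batch.
g_global = grad_mean_loss(x, y, w)

# DDP-style view: two "processes", each with a local batch of 128.
# Each averages its loss over its own 128 samples; the resulting
# per-process gradients are then averaged across processes.
g_rank0 = grad_mean_loss(x[:128], y[:128], w)
g_rank1 = grad_mean_loss(x[128:], y[128:], w)
g_ddp = (g_rank0 + g_rank1) / 2

print(np.isclose(g_global, g_ddp))  # → True: the two computations agree
```

In this toy case, averaging the per-process mean-reduced gradients equals the gradient of the loss averaged over the global batch (because the two local batches are the same size). What I want to confirm is whether DDP's gradient averaging gives exactly this behavior, so that no extra division by the world size is needed on my side.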