Should we split batch_size according to ngpu_per_node when using DistributedDataParallel?

I agree with all your analysis on the magnitude of the gradients, and I agree that it depends on the loss function. But even with an MSE loss fn, it can lead to different conclusions:

  1. If the forward-backward pass has processed 8X the data, we should set the lr to 8X, meaning the model should take a larger step when it has processed more data, since the gradient is more accurate. (IIUC, this is what you advocate for.)
  2. If the gradient is of the same magnitude, we should keep the 1X lr, especially when approaching convergence. Otherwise, using an 8X lr makes it more likely to overshoot and hurt the converged model's accuracy.

After reading your analysis, I realized that with an MSE loss fn, the discussion is mostly independent of DDP. The question then becomes: if I increase the batch size by k, how should I adjust the learning rate? And that is still an open question. :slight_smile:
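
To make that concrete, here is a minimal sketch (plain PyTorch, a toy linear model and random data of my own choosing, no actual process group) of why the discussion is independent of DDP for a mean-reduced MSE loss: averaging per-shard gradients, which is effectively what DDP's all-reduce does across replicas, gives the same gradient as one pass over the full batch on a single device, so the only remaining question is how lr should track batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction by default

# "Single GPU" gradient over the full batch of 8 samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# "DDP-style" gradient: each of 2 equal shards computes a mean-reduced loss
# over its local batch of 4, then the per-shard gradients are averaged
# (what the all-reduce does across ranks).
shard_grads = []
for xs, ys in zip(x.chunk(2), y.chunk(2)):
    model.zero_grad()
    loss_fn(model(xs), ys).backward()
    shard_grads.append(model.weight.grad.clone())
avg_grad = torch.stack(shard_grads).mean(dim=0)

print(torch.allclose(full_grad, avg_grad))  # True: same gradient magnitude
```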
