How does DistributedDataParallel handle ignore classes when averaging gradients?

How does gradient averaging work in DistributedDataParallel training? I am particularly interested in what happens when the batches have masked or ignored data, e.g. with semantic segmentation.

For example, let’s say I have 4 GPUs and I am training a semantic segmentation network on a dataset with an ignore class. As I understand it, in the DataParallel setting the outputs are gathered on GPU0, the loss is computed there, and the gradient is backpropagated through each GPU’s model replica. In the DistributedDataParallel case, losses L0, L1, L2, L3 are computed separately, one for each GPU’s share of the batch; each loss is backpropagated through its own GPU’s model, and the gradients are averaged along the way.

With DataParallel, the presence of an ignore class makes no difference: even if one GPU’s mini-batch has a lopsided number of ignore pixels, the loss is computed over the whole gathered batch, so it is already the average over all valid pixels. But what happens with DistributedDataParallel when one GPU has a lopsided share of ignore pixels? There does not seem to be any mechanism for weighting the average of the gradients, yet in this case L0, L1, L2, and L3 ought to contribute in proportion to each GPU’s share of the valid pixels.
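To make the weighting concern concrete, here is a small worked example in plain Python. The per-rank numbers are made up for illustration, and it assumes each rank's criterion averages over its own valid pixels (as `nn.CrossEntropyLoss(ignore_index=..., reduction="mean")` does when no class weights are set). Because gradients are linear in the loss, averaging per-rank gradients equals the gradient of the average of the per-rank losses, so comparing the losses is enough.

```python
# Made-up per-rank statistics for one step with 4 GPUs:
# (sum of per-pixel loss over valid pixels, number of valid pixels).
rank_stats = [
    (10.0, 10),    # rank 0: mostly ignore pixels, only 10 valid
    (180.0, 90),   # rank 1
    (200.0, 100),  # rank 2
    (300.0, 100),  # rank 3
]


def mean_of_local_means(stats):
    """Each rank averages over its own valid pixels; ranks are then averaged equally."""
    return sum(s / n for s, n in stats) / len(stats)


def global_valid_mean(stats):
    """Average over all valid pixels in the global batch (the DataParallel-style result)."""
    return sum(s for s, _ in stats) / sum(n for _, n in stats)


print(mean_of_local_means(rank_stats))  # 2.0
print(global_valid_mean(rank_stats))    # 2.3
```

Rank 0 holds only 1/30 of the valid pixels but contributes 1/4 of the equally averaged result, which is the imbalance described above.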

Is there some way to handle this ignore class imbalance during distributed training?


How does gradient averaging work in DistributedDataParallel training?

  1. Every DDP instance has its own copy of the local model, and DDP sets up post-autograd hooks on every parameter (i.e., DDP hooks).
  2. In every forward pass, DDP feeds the input data to its own local model and returns the local output to the application.
  3. The application uses the local output to compute the local loss and calls backward on it, which kicks off the local autograd engine to compute gradients for all parameters. When a local gradient becomes ready, it triggers the corresponding DDP hook. The DDP hook runs allreduce on that gradient and writes the averaged gradient back to the parameter.grad field.
  4. When backward is done, every parameter.grad should hold the globally averaged gradient. The optimizer can then consume that grad to update the parameters.
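The four steps above can be mimicked with a toy, single-process simulation in plain Python. Everything here (the `Replica` class, `allreduce_mean`, the one-parameter model) is hypothetical and only illustrates the data flow. Real DDP allreduces gradient buckets from autograd hooks while backward is still running, whereas this sketch averages after all local backward passes finish.

```python
# Toy, single-process stand-in for DDP (no torch, no real communication).
# Model: prediction = w * x, local loss = mean((w*x - y)**2) over the shard.

class Replica:
    """One rank's local copy of a one-parameter model (step 1)."""

    def __init__(self, w):
        self.w = w
        self.grad = None

    def forward(self, xs):
        # Step 2: the local model produces the local output.
        return [self.w * x for x in xs]

    def backward(self, xs, preds, ys):
        # Step 3: local gradient of the local mean squared error w.r.t. w.
        self.grad = sum(2 * x * (p - y) for x, p, y in zip(xs, preds, ys)) / len(xs)


def allreduce_mean(replicas):
    """Stand-in for the DDP hook: sum grads over ranks, divide by world size,
    and write the average back to every replica's grad."""
    avg = sum(r.grad for r in replicas) / len(replicas)
    for r in replicas:
        r.grad = avg


def ddp_step(replicas, shards, lr=0.1):
    for r, (xs, ys) in zip(replicas, shards):
        r.backward(xs, r.forward(xs), ys)
    allreduce_mean(replicas)
    # Step 4: each optimizer applies the same averaged gradient.
    for r in replicas:
        r.w -= lr * r.grad


replicas = [Replica(0.0) for _ in range(4)]  # same initial weight on every rank
shards = [  # each rank's share of the batch, all drawn from y = 2x
    ([1.0, 2.0], [2.0, 4.0]),
    ([3.0], [6.0]),
    ([1.0], [2.0]),
    ([2.0, 2.0], [4.0, 4.0]),
]
for _ in range(50):
    ddp_step(replicas, shards)
print([r.w for r in replicas])  # identical values, close to 2.0
```

Because every replica starts from the same weight and applies the same averaged gradient, the copies stay in sync without ever exchanging parameters.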

Is there some way to handle this ignore class imbalance during distributed training?

DDP simply averages all local gradients (it sums them and then divides by the world size). So it should work as long as the ignored data do not contribute to the local gradients.
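This averaging gives every rank the same weight, so if each rank's loss is a mean over its own valid pixels, a rank with few valid pixels is over-weighted relative to a single global mean. One common workaround (not something DDP does automatically) is to have each rank compute the sum of its loss over valid pixels, sum the valid-pixel count across ranks, and scale the local loss by `world_size / global_valid_count`; DDP's division by the world size then yields the gradient of the global mean. The sketch below simulates this in plain Python with hypothetical helpers and a toy one-parameter model, assuming DDP's default averaging (no custom communication hook).

```python
# Toy simulation of ignore-aware loss scaling under DDP-style averaging.
# Each "rank" holds pixels as (x, y, valid); valid=False marks ignored pixels.
# Model: prediction = w * x, per-pixel loss = (w*x - y)**2.

def pixel_grad(w, x, y):
    """d/dw of (w*x - y)**2."""
    return 2 * x * (w * x - y)


def local_sum_grad(w, pixels):
    """Gradient of this rank's SUM of loss over valid pixels only."""
    return sum(pixel_grad(w, x, y) for x, y, valid in pixels if valid)


def ddp_average(grads):
    """Stand-in for DDP's allreduce: sum over ranks, divide by world size."""
    return sum(grads) / len(grads)


def rescaled_ddp_grad(w, shards):
    """Scale each local sum by world_size / global_valid before averaging."""
    world_size = len(shards)
    # In a real job this count would be summed across ranks with an allreduce.
    global_valid = sum(1 for shard in shards for _, _, valid in shard if valid)
    scale = world_size / global_valid
    return ddp_average([local_sum_grad(w, shard) * scale for shard in shards])


def reference_grad(w, shards):
    """Gradient of the mean loss over every valid pixel in the global batch."""
    valid = [(x, y) for shard in shards for x, y, ok in shard if ok]
    return sum(pixel_grad(w, x, y) for x, y in valid) / len(valid)


shards = [
    [(1.0, 1.0, True)] + [(9.0, 0.0, False)] * 3,  # rank 0: mostly ignored
    [(1.0, 3.0, True)] * 4,                         # rank 1: all valid
]
w = 0.5
print(rescaled_ddp_grad(w, shards), reference_grad(w, shards))  # should match
```

In PyTorch the same idea would roughly correspond to computing the local loss with `F.cross_entropy(logits, target, ignore_index=..., reduction="sum")`, summing a tensor holding the local valid-pixel count with `torch.distributed.all_reduce`, and multiplying the loss by `world_size / count` before calling `backward()`. If a global batch can be entirely ignored, guard against a count of zero.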