Why do workers end up with different loss values?

From the docs:

Constructor, forward method, and differentiation of the output (or a function of the output of this module) is a distributed synchronization point. Take that into account in case different processes might be executing different code.

I’m printing the loss from each worker, and I get the following output:

| distributed init (rank 2): tcp://localhost:1947
| distributed init (rank 3): tcp://localhost:1947
| distributed init (rank 0): tcp://localhost:1947
| distributed init (rank 1): tcp://localhost:1947
| initialized host gnode03 as rank 3
| initialized host gnode03 as rank 1
| initialized host gnode03 as rank 2
| initialized host gnode03 as rank 0
rank 0 loss 920.7410278320312
rank 1 loss 1102.2825927734375
rank 3 loss 765.515869140625
rank 2 loss 642.1211547851562
rank 2 loss 950.1659545898438
rank 1 loss 863.4507446289062
rank 3 loss 1053.586669921875
rank 0 loss 551.5623168945312
rank 0 loss 679.0967407226562
rank 2 loss 970.89892578125
rank 1 loss 1246.443359375
rank 3 loss 1169.9415283203125
rank 0 loss 798.79833984375

Does this mean I have to explicitly aggregate and average the total loss by total batch size, or is this handled internally? The segment that prints the above looks like this:

        sample = move_to(sample, self.device)
        loss, logging_outputs = self.model(sample)
        loss.backward()  # gradients must exist before clipping
        clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
        return loss.item()
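If you want every rank to report the same global number, you can average the per-rank loss yourself with an `all_reduce` (DDP averages gradients during `backward()`, not the loss value you print). A minimal sketch using a single-process `gloo` group purely for illustration; `global_average_loss` is a hypothetical helper, not part of any library:

```python
import os
import torch
import torch.distributed as dist

# Single-process demo group; in real training the launcher
# (torchrun / torch.distributed.launch) sets these up per rank.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

def global_average_loss(loss: torch.Tensor) -> float:
    """Average a scalar loss across all ranks, for logging only."""
    loss = loss.detach().clone()          # don't touch the autograd graph
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return (loss / dist.get_world_size()).item()

avg = global_average_loss(torch.tensor(920.74))  # with one rank, this is just the input
dist.destroy_process_group()
```

With more than one rank, each process would pass in its own local loss and get back the same averaged value, so all workers print an identical number.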

Ah! This is indeed what you meant in DistributedDataParallel loss compute and backpropogation?. Pasting my answer there here as well for posterity and the indexers.

Each process computes its own output, using its own input, with its own activations, and computes its own loss. Then on loss.backward() all processes reduce their gradients. As loss.backward() returns, the gradients of your model parameters will be the same, and the optimizer in each process will perform the exact same update to the model parameters.
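The flow above can be sketched with a CPU/`gloo` setup (single process here only so the snippet is self-contained; in real training each rank runs the same script on its own data shard):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 1))            # wraps the local model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)                         # each rank gets its own batch
loss = model(x).pow(2).mean()                 # per-rank loss: differs across ranks
loss.backward()                               # gradients are all-reduced here
opt.step()                                    # identical update on every rank

dist.destroy_process_group()
```

So the printed losses differing per rank is expected; the synchronization happens on the gradients inside `backward()`.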

Note that this is only the case if you use torch.nn.parallel.DistributedDataParallel. If you don’t, you’ll need to take care of gradient synchronization yourself.
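Without DDP, a manual sync would mirror what DDP's backward hooks do: all-reduce each parameter's gradient and divide by the world size. A sketch under the same single-process `gloo` assumption as above:

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 1)                 # plain model, no DDP wrapper
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()                               # gradients stay local here

# Average gradients across ranks by hand, parameter by parameter.
for p in model.parameters():
    if p.grad is not None:
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        p.grad /= dist.get_world_size()

dist.destroy_process_group()
```

In practice you'd almost always prefer the DDP wrapper, which overlaps this communication with the backward pass instead of doing it afterwards.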