The reason for this is the loss calculation, which takes place on the default device.
As a small side note, this is also the reason you see a bit more memory consumption on one device.
You could avoid it by adding your loss directly into your model and calculating the loss on each replica.
Then only the (scalar) loss value has to be reduced.
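
Here is a minimal sketch of that idea, assuming you are using `nn.DataParallel`. `ModelWithLoss` is a hypothetical wrapper name, and the `nn.Linear` model and random tensors are just placeholders for your own model and data:

```python
import torch
import torch.nn as nn


class ModelWithLoss(nn.Module):
    """Hypothetical wrapper: runs the model and computes the loss in forward()."""

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, inputs, targets):
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        # Return shape [1] so the per-replica losses can be gathered along dim 0.
        return loss.unsqueeze(0)


model = nn.Linear(10, 2)  # placeholder model
wrapped = ModelWithLoss(model, nn.CrossEntropyLoss())

inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

if torch.cuda.is_available():
    # Each replica now gets its chunk of inputs AND targets.
    wrapped = nn.DataParallel(wrapped).cuda()
    inputs, targets = inputs.cuda(), targets.cuda()

# One loss value per replica comes back; reduce them to a single scalar.
loss = wrapped(inputs, targets).mean()
loss.backward()
```

Note that averaging the per-replica means only matches the loss over the full batch if the batch splits evenly across the devices.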