No, the NaN loss will cause all gradients on this rank to become NaNs as well. The gradient synchronization (the all-reduce in DDP) will then propagate NaN gradients to all ranks, so the scaler.step(optimizer) call will skip the parameter update. The following scaler.update() call will then decrease the scaling factor.
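A minimal, single-process sketch (no DDP) of this skip-and-backoff behavior; the toy model, data, and hyperparameters are made up for illustration, and a CUDA device is assumed since GradScaler scaling is a no-op on the CPU in older PyTorch releases. In a real DDP run every rank would hit the same code path after the all-reduce has spread the NaNs.

```python
import torch

device = "cuda"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 4, device=device)
loss = model(x).mean() * float("nan")  # force a NaN loss for illustration

optimizer.zero_grad()
scaler.scale(loss).backward()          # all gradients are now NaN
print("scale before update:", scaler.get_scale())  # 65536.0 by default

scaler.step(optimizer)                 # detects non-finite grads -> skips optimizer.step()
scaler.update()                        # multiplies the scale by backoff_factor (0.5 by default)
print("scale after update:", scaler.get_scale())   # 32768.0
```

Running this prints the scale before and after the skipped step, showing that the parameters are left untouched while the scaling factor is halved.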