DDP and parameter sync

I have a question regarding how the parameter synchronization works in DDP.

I am training a Variational Autoencoder in a distributed setting with 4 GPUs, using torchrun.

I am also using AMP (I don’t know whether this is related).

My main method looks as follows (removed the logging/prints):

# Get data loaders.
train_loader, val_loader, train_augmentations = prepare_data(LOCAL_RANK, WORLD_SIZE, args.data_path, config)

# create model
model = AutoEncoder(config['autoencoder'], config['resolution'])

# load checkpoint if resume
if args.resume_from is not None:

    checkpoint = torch.load(args.resume_from, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(LOCAL_RANK)
    init_epoch = checkpoint['epoch']
    optimizer_state = checkpoint['optimizer']
    grad_scaler_state = checkpoint['grad_scaler']

else:
    init_epoch = 0
    optimizer_state = None
    grad_scaler_state = None

    model = model.to(LOCAL_RANK)

# find final learning rate
learning_rate = float(train_conf['base_lr'])
min_learning_rate = float(train_conf['min_lr'])
weight_decay = float(train_conf['weight_decay'])
eps = float(train_conf['eps'])

# ddp model, optimizer, scheduler, scaler
ddp_model = DDP(model, device_ids=[LOCAL_RANK])

optimizer = torch.optim.Adamax(ddp_model.parameters(), learning_rate, weight_decay=weight_decay, eps=eps)
if optimizer_state is not None:
    optimizer.load_state_dict(optimizer_state)

# Note: custom implementation of schedulers
total_training_steps = len(train_loader) * train_conf['epochs']
if train_conf['warmup_epochs'] is not None:
    warmup_steps = int(len(train_loader) * train_conf['warmup_epochs'])
    scheduler = LinearCosineScheduler(0, total_training_steps, learning_rate,
                                      min_learning_rate, warmup_steps)
else:
    scheduler = CosineScheduler(0, total_training_steps, learning_rate, min_learning_rate)

grad_scaler = GradScaler(2 ** 10)  # scale gradient for AMP
if grad_scaler_state is not None:
    grad_scaler.load_state_dict(grad_scaler_state)

for epoch in range(init_epoch, train_conf['epochs']):

    # Training
    ddp_model.train()

    for step, x in enumerate(tqdm(train_loader)):

        x = train_augmentations(x[0])

        # scheduler step
        lr = scheduler.step(global_step)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # use autocast for step (AMP)
        optimizer.zero_grad()
        with autocast():

            # forward pass
            reconstructions, kl_terms = ddp_model.module(x)

            # reconstruction loss
            rec_loss = criterion(reconstructions, x)

            # compute final loss
            loss = torch.mean(rec_loss + kl_terms)

        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()

    # Validation
    dist.barrier()
    ddp_model.eval()
    
    [...]
    
    # Save checkpoint (after validation)
    if WORLD_RANK == 0:

        checkpoint_dict = {
            'epoch': epoch,
            'configuration': config,
            'state_dict': ddp_model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'grad_scaler': grad_scaler.state_dict(),
        }
        ckpt_file = f"{args.checkpoint_base_path}/{args.run_name}/epoch={epoch:02d}.pt"
        torch.save(checkpoint_dict, ckpt_file)

Training is apparently correct and the loss decreases nicely. However, after saving and re-loading the model (e.g. to continue training), I notice that the loss starts much higher than where it stopped.
For example, if the loss at epoch 10 is 1.2345 and I save the model, re-load it, and continue from epoch 11, the loss will be much higher (something like 8.1234).

After a lot of research, I found that the parameters slightly differ across ranks! So when the checkpoint is saved from rank 0’s copy, the saved model is not the same as the one on the other ranks!

To solve the issue, I need to manually sync parameters after each training epoch:

[...]
for epoch in range(init_epoch, train_conf['epochs']):

    # Training
    [...]

    for p in ddp_model.module.parameters():
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data /= WORLD_SIZE

    # Validation
    [...]

By doing so, re-loading the model causes the loss to continue from approximately the same value, without exploding.

The question is: when is a manual parameter sync necessary with DDP?

Could it be caused by some custom modules inside the AutoEncoder object?

Thank you for reading

Wrapping the model into DDP will synchronize all parameters across the ranks. I’m unsure how you are restoring the model, but you might want to load the state_dict before wrapping the model into DDP. A manual parameter sync should not be needed.
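In code, the order would be roughly as follows (a minimal sketch reusing the names from your snippet above, so AutoEncoder, config, args and LOCAL_RANK are assumed to exist):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Sketch: restore the plain model on every rank *before* wrapping it,
# so DDP's initial broadcast starts from the restored weights.
model = AutoEncoder(config['autoencoder'], config['resolution'])

if args.resume_from is not None:
    # map_location='cpu' avoids materializing every checkpoint on GPU 0
    checkpoint = torch.load(args.resume_from, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])

model = model.to(LOCAL_RANK)
ddp_model = DDP(model, device_ids=[LOCAL_RANK])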

Hi @ptrblck, thank you for the reply.

I restore the model as follows:

I first load the state_dict on each rank, and then wrap it with DDP.
Is this correct?

Moreover, note that I sync the parameters after each training epoch. I ran a test and noticed that after each epoch, the parameters on the different ranks differ. So I guess something inside the training procedure drives the parameters out of sync for some reason.

This is indeed weird as the DDP workflow would:

  • broadcast the state_dict from rank 0 during the DDP initialization, making sure the state_dicts of the models on all ranks are equal
  • execute the forward pass, calculate the loss
  • execute the backward pass, computing the gradients and all-reducing them across all ranks
  • now all ranks have the same model.state_dict() as well as the same gradients
  • update the parameters via optimizer.step()
  • all models have again the same (updated) parameters
  • repeat

I thus don’t know where exactly the models diverge, but you could debug it by printing the parameters at different points during the training.
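As a rough sketch, a hypothetical check_param_sync helper along these lines would compare each parameter against rank 0’s copy without modifying it (it assumes the default process group is already initialized):

import torch.distributed as dist

def check_param_sync(model, atol=0.0):
    # Broadcast rank 0's values into a scratch tensor and compare against the
    # local copy, so the local parameters are left untouched (unlike an
    # in-place all_reduce). Non-zero ranks report mismatches relative to rank 0.
    mismatched = []
    for name, p in model.named_parameters():
        ref = p.detach().clone()
        dist.broadcast(ref, src=0)  # ref now holds rank 0's values on every rank
        max_diff = (p.detach() - ref).abs().max().item()
        if max_diff > atol:
            mismatched.append((name, max_diff))
    return mismatched

# e.g. right after optimizer.step():
# diffs = check_param_sync(ddp_model.module)
# if diffs:
#     print(f"rank {dist.get_rank()}: {len(diffs)} params differ, first: {diffs[0]}")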

Hi @ptrblck, I followed your suggestion and added some checks inside the training method. The forward + backward pass now looks as follows:

for step, x in enumerate(tqdm(train_loader)):

    x = train_augmentations(x[0])
  
    # scheduler step
    lr = scheduler.step(global_step)
  
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
  
    # use autocast for step (AMP)
    optimizer.zero_grad()
    with autocast():
  
        # forward pass
        reconstructions, kl_terms = ddp_model.module(x)
  
        # reconstruction loss
        rec_loss = criterion(reconstructions, x)
  
        # compute final loss
        loss = torch.mean(rec_loss + kl_terms)
    
    # TEST (BEFORE BACKWARD)
    flag = False
    count = 0
    for n, p in model.named_parameters():
        a = p.detach().clone()
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data /= WORLD_SIZE
        if (a.data - p.data).sum().item() > 0:
            flag = True
            count += 1
    if flag:
        raise RuntimeError(f"BEFORE BACKWARD: {count} parameters diverge!")

    grad_scaler.scale(loss).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()

    # TEST (AFTER BACKWARD and UPDATE)
    flag = False
    count = 0
    for n, p in model.named_parameters():
        a = p.detach().clone()
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data /= WORLD_SIZE
        if (a.data - p.data).sum().item() > 0:
            flag = True
            count += 1
            if count == 1:
                print(f"AFTER BACKWARD: {n} is different across ranks!")
                print(f"SUM OF DIFFERENCES: {(a.data - p.data).sum().item()}")
    if flag:
        raise RuntimeError(f"AFTER BACKWARD: {count} parameters diverge!")

Here I check whether the parameters diverge at two points (after the forward pass and after the backward pass + update).

The above code prints the following:

[INFO] Epoch 1/800                                                                                                                 
  1%|▍                        | 1/195 [00:19<1:01:47, 19.11s/it]
AFTER BACKWARD: const_prior is different across ranks!                                                                             
SUM OF DIFFERENCES: 1.7440303054172546e-05 
                                                                                        
AFTER BACKWARD: preprocessing_block.init_conv.weight is different across ranks!
SUM OF DIFFERENCES: 3.7593381421174854e-06  

AFTER BACKWARD: preprocessing_block.init_conv.weight is different across ranks!
SUM OF DIFFERENCES: 7.4539275374263525e-06                                                                                                                               
                                                                                                                                                                              
AFTER BACKWARD: preprocessing_block.init_conv.log_weight_norm is different across ranks!                                           
SUM OF DIFFERENCES: 5.960464477539062e-07                                                                                          
                                                                             
Traceback (most recent call last):                                                                                                 
  [...] 
RuntimeError: AFTER BACKWARD: 843 parameters diverge!                                                                              
RuntimeError: AFTER BACKWARD: 883 parameters diverge!
RuntimeError: AFTER BACKWARD: 788 parameters diverge!                         
RuntimeError: AFTER BACKWARD: 837 parameters diverge!

Apparently, something goes wrong in the gradient computations / parameter updates when using AMP. The differences are very small (between 1e-05 and 1e-07), but enough to cause an inconsistency when saving and re-loading the model.
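One thing I could try next, to see whether AMP itself is the culprit, is running the same divergence check with autocast and the grad scaler disabled, roughly like this (untested sketch, keeping the rest of the loop unchanged):

import torch
from torch.cuda.amp import GradScaler, autocast

# Sketch: turn AMP off via the enabled flags so the same loop runs in plain FP32;
# GradScaler becomes a no-op and autocast leaves the dtypes untouched.
use_amp = False
grad_scaler = GradScaler(enabled=use_amp)

with autocast(enabled=use_amp):
    reconstructions, kl_terms = ddp_model.module(x)
    rec_loss = criterion(reconstructions, x)
    loss = torch.mean(rec_loss + kl_terms)

grad_scaler.scale(loss).backward()  # no scaling when disabled
grad_scaler.step(optimizer)         # falls back to a plain optimizer.step()
grad_scaler.update()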

I have no idea how to solve this, but I remain open to further suggestions…

Thank you

No, I don’t think these small numerical mismatches are concerning, as they are expected if non-deterministic algorithms are picked. If differences of ~1e-5 caused significant divergence in your model, a single-GPU run in FP32 would already suffer from it and you might need to use float64.
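If you want to verify that hypothesis, you could opt into deterministic kernels and re-run your check (a sketch; some ops will raise an error if they have no deterministic implementation, and throughput will drop):

import os
import torch

# Sketch: request deterministic kernels before building the model / CUDA context.
# CUBLAS_WORKSPACE_CONFIG is required by some cuBLAS ops in deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)  # errors out on ops without a deterministic kernel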