I have a question about how parameter synchronization works in DDP.
I am training a Variational Autoencoder in a distributed setting with 4 GPUs, using torchrun.
I am also using AMP (not sure whether that is related).
My main method looks as follows (logging/prints removed):
# Get data loaders.
train_loader, val_loader, train_augmentations = prepare_data(LOCAL_RANK, WORLD_SIZE, args.data_path, config)
# create model
model = AutoEncoder(config['autoencoder'], config['resolution'])
# load checkpoint if resume
if args.resume_from is not None:
    checkpoint = torch.load(args.resume_from, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(LOCAL_RANK)
    init_epoch = checkpoint['epoch']
    optimizer_state = checkpoint['optimizer']
    grad_scaler_state = checkpoint['grad_scaler']
else:
    init_epoch = 0
    optimizer_state = None
    grad_scaler_state = None
    model = model.to(LOCAL_RANK)
# find final learning rate
learning_rate = float(train_conf['base_lr'])
min_learning_rate = float(train_conf['min_lr'])
weight_decay = float(train_conf['weight_decay'])
eps = float(train_conf['eps'])
# ddp model, optimizer, scheduler, scaler
ddp_model = DDP(model, device_ids=[LOCAL_RANK])
optimizer = torch.optim.Adamax(ddp_model.parameters(), learning_rate, weight_decay=weight_decay, eps=eps)
if optimizer_state is not None:
    optimizer.load_state_dict(optimizer_state)
# Note: custom implementation of schedulers
total_training_steps = len(train_loader) * train_conf['epochs']
if train_conf['warmup_epochs'] is not None:
    warmup_steps = int(len(train_loader) * train_conf['warmup_epochs'])
    scheduler = LinearCosineScheduler(0, total_training_steps, learning_rate,
                                      min_learning_rate, warmup_steps)
else:
    scheduler = CosineScheduler(0, total_training_steps, learning_rate, min_learning_rate)
grad_scaler = GradScaler(2 ** 10) # scale gradient for AMP
if grad_scaler_state is not None:
    grad_scaler.load_state_dict(grad_scaler_state)
for epoch in range(init_epoch, train_conf['epochs']):
    # Training
    ddp_model.train()
    for step, x in enumerate(tqdm(train_loader)):
        x = train_augmentations(x[0])
        # scheduler step
        lr = scheduler.step(global_step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # use autocast for the step (AMP)
        optimizer.zero_grad()
        with autocast():
            # forward pass
            reconstructions, kl_terms = ddp_model.module(x)
            # reconstruction loss
            rec_loss = criterion(reconstructions, x)
            # compute final loss
            loss = torch.mean(rec_loss + kl_terms)
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()

    # Validation
    dist.barrier()
    ddp_model.eval()
    [...]

    # Save checkpoint (after validation)
    if WORLD_RANK == 0:
        checkpoint_dict = {
            'epoch': epoch,
            'configuration': config,
            'state_dict': ddp_model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'grad_scaler': grad_scaler.state_dict(),
        }
        ckpt_file = f"{args.checkpoint_base_path}/{args.run_name}/epoch={epoch:02d}.pt"
        torch.save(checkpoint_dict, ckpt_file)
Training is apparently correct: the loss decreases nicely. However, after saving and reloading the model (e.g. to continue training), I notice that the loss starts much higher than where it stopped.
E.g. if the loss at epoch 10 is 1.2345 and I save the model, reload it, and continue from epoch 11, the loss will be much higher (something like 8.1234).
After a lot of research, I found that the parameters differ slightly across ranks! So when the checkpoint is saved from the copy on rank 0, the saved model is not the same as the one on the other ranks!
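For reference, this is roughly the kind of check that revealed the mismatch (just a sketch, not the exact code from my script; the helper name and tolerance are arbitrary):

import torch
import torch.distributed as dist

def check_param_sync(ddp_model, world_size, atol=1e-7):
    # Gather each parameter from every rank and compare against rank 0's copy.
    for name, p in ddp_model.module.named_parameters():
        gathered = [torch.empty_like(p.data) for _ in range(world_size)]
        dist.all_gather(gathered, p.data)
        for rank, other in enumerate(gathered):
            max_diff = (other - gathered[0]).abs().max().item()
            if max_diff > atol:
                print(f"{name}: rank {rank} differs from rank 0 by {max_diff:.3e}")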
To solve the issue, I need to manually sync parameters after each training epoch:
[...]
for epoch in range(init_epoch, train_conf['epochs']):
    # Training
    [...]
    for p in ddp_model.module.parameters():
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data /= WORLD_SIZE
    # Validation
    [...]
By doing so, reloading the model lets the loss continue from approximately the same value, without exploding.
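I assume that broadcasting rank 0's parameters to the other ranks, instead of averaging them, would have the same effect (I have not actually tried this variant):

# Untested alternative: make every rank adopt rank 0's copy of the weights.
with torch.no_grad():
    for p in ddp_model.module.parameters():
        dist.broadcast(p.data, src=0)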
The question is: when is manual synchronization of parameters necessary with DDP?
Could it be caused by some custom modules in the AutoEncoder object?
Thank you for reading.