FusedLAMB optimizer, fp16 and grad_accumulation on DDP

I am training a BERT model with PyTorch and, after reading through many different versions and examples, I still can’t tell what the correct implementation of DDP (DistributedDataParallel) should look like.

I am working with world_size = 8: 1 node and 8 GPUs. As far as I understand, DDP spawns one process per rank and trains the same model on a different shard of each batch. Each process then computes its gradients and an all-reduce averages them, so every GPU applies the same update to its copy of the model.
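For context, here is a minimal sketch of the per-process setup I am assuming (NCCL process group plus a DistributedSampler so every rank reads a different shard; local_rank, train_dataset and args.per_gpu_train_batch_size are just placeholders):

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

# One process per GPU; local_rank identifies the GPU on this node
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

# Each rank sees a disjoint shard of the dataset
train_sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=args.per_gpu_train_batch_size)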

However, what happens with the loss? Should I call loss.mean(), or is the only thing that matters in that case the gradients of each replica?

I can’t find a good example where my particular combination (torch-native mixed precision, the apex FusedLAMB optimizer and DDP) is implemented, so it’s hard to know whether my implementation is correct.

Any resource would be really helpful!

Here is a small summary of the code I have:

model = model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]   

optimizer = FusedLAMB(optimizer_grouped_parameters,
                             lr=args.learning_rate)

scheduler = get_polynomial_decay_schedule_with_warmup(optimizer=optimizer,
                                                             num_warmup_steps=args.warmup_steps,
                                                             num_training_steps=t_total,
                                                             power=0.5)
fp16_scaler = torch.cuda.amp.GradScaler(enabled=True)

tr_loss, logging_loss = 0.0, 0.0   

model.zero_grad()
train_iterator = trange(
    epochs_trained, int(args.num_train_epochs), desc="Epoch"
)

for _ in train_iterator:
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")
    for step, batch in enumerate(epoch_iterator):

        # Skip past any already trained steps if resuming training
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            continue

        inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)

        inputs, labels = inputs.to(args.device), labels.to(args.device)
        model.train()

        with torch.cuda.amp.autocast():
            outputs = model(inputs, masked_lm_labels=labels)
            loss = outputs[0]  # model outputs are always a tuple in transformers

        tr_loss += loss.item()

        # Scale the loss down so the accumulated gradients match a full batch,
        # and call backward() on every step so the gradients accumulate
        loss = loss / args.gradient_accumulation_steps
        fp16_scaler.scale(loss).backward()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            fp16_scaler.step(optimizer)   # unscales gradients, skips the step on inf/nan
            fp16_scaler.update()
            scheduler.step()
            optimizer.zero_grad()


The DeepLearningExamples - BERT repository should give you a working example using these utils.
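For reference, the usual pattern for combining AMP, gradient accumulation, and DDP looks roughly like this (names follow your snippet; the no_sync() context is optional and just skips the gradient all-reduce on the accumulation-only steps):

from contextlib import nullcontext

accum = args.gradient_accumulation_steps

for step, (inputs, labels) in enumerate(train_dataloader):
    # Only synchronize (all-reduce) gradients on the step that will call optimizer.step()
    maybe_no_sync = model.no_sync() if (step + 1) % accum != 0 else nullcontext()

    with maybe_no_sync:
        with torch.cuda.amp.autocast():
            outputs = model(inputs, masked_lm_labels=labels)
            loss = outputs[0] / accum
        fp16_scaler.scale(loss).backward()

    if (step + 1) % accum == 0:
        fp16_scaler.step(optimizer)
        fp16_scaler.update()
        scheduler.step()
        optimizer.zero_grad()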


Thank you very much for the resource @ptrblck !

Sorry to bother you again… I have one naive question about the local_rank argument. As far as I understand, the script is run in a separate process for each GPU and dist.init_process_group is what handles the synchronization. However, in most of the examples I base my code on, the local_rank in each script is expected to be either -1 or 0.

I understand that 0 is the “master” GPU which gathers everything, but what does a local_rank of -1 mean? Also, I can’t find where the local_rank argument gets set so that each spawned script runs on its own GPU.

Thank you very much again for your answers!

I guess the code would set the CUDA device via:

torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

and initialize the process group afterwards.

args.local_rank is set by the torch.distributed.launch call, which passes this argument to each spawned process (or sets the corresponding environment variables).
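As a minimal sketch, assuming the script is started with python -m torch.distributed.launch --nproc_per_node=8 train.py, the argument is usually wired up like this; the default of -1 then simply means the script was not launched in distributed mode:

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank=<0..nproc_per_node-1> to every
# process it spawns; -1 is the conventional "not distributed" default
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

if args.local_rank != -1:
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    dist.init_process_group(backend="nccl")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")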
