DistributedDataParallel loss compute and backpropagation?

I’m trying to get DistributedDataParallel to work on a codebase, using pytorch/fairseq as a reference implementation. I’m finding the implementation there difficult to comprehend, and I’ve opened an issue about it. Below is a (hopefully) complete relevant extract. The uncommented part I’ve already got working, and the loss is converging.

    def train_step(self, sample):
        self.model.train()
        self._optimizer.zero_grad()
        sample = move_to(sample, self.device)
        loss, batch_sizes = self.model(sample)
        # 1: Is the following done implicitly?
        #    It seems to be missing in the fairseq code.
        # all_gather([loss, batch_sizes])
        # loss = loss.sum() / batch_sizes.sum()
        loss.backward()
        # 2: Something similar to the following
        #    exists. What is happening here?
        # for p in parameters_optimized:
        #     p.grad = p.grad * distributed_world_size / batch_sizes.sum()
        self._optimizer.step()
        return loss.item()

My concerns are:

  1. Shouldn’t I be doing an all-gather as indicated in the code? Is this done implicitly?
  2. What is happening in the second segment?

Hi @jerinphilip,

  1. Why would you need a gather on the loss? I can see why you might think loss aggregation is needed for distributed training, but here is what actually happens: each process computes its own output, using its own input and its own activations, and computes its own loss. Then, during loss.backward(), all processes reduce (average) their gradients. When loss.backward() returns, the gradients of your model parameters are identical across processes, and the optimizer in each process performs exactly the same update (see the sketch below).
  2. This normalizes the gradients w.r.t. the total number of processes. If you end up using torch.nn.parallel.DistributedDataParallel, this is already done for you. It is possible this is still part of fairseq because earlier versions used a custom approach to distributed data parallelism, whereas newer versions can use the upstream wrapper directly (IIRC).
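
To make point 1 concrete, here is a minimal sketch of what a DDP train step looks like (train_step, loss_fn and the per-rank tensors are placeholders here, not fairseq’s actual API): there is no explicit all-gather or all-reduce on the loss itself; the gradient averaging happens inside loss.backward() through DDP’s backward hooks.

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train_step(ddp_model, loss_fn, optimizer, inputs, targets):
        # assumes ddp_model = DDP(model, device_ids=[rank]) was created beforehand
        ddp_model.train()
        optimizer.zero_grad()
        outputs = ddp_model(inputs)        # each rank uses its own inputs/activations
        loss = loss_fn(outputs, targets)   # each rank computes its own local loss
        loss.backward()                    # DDP all-reduces (averages) the gradients here
        optimizer.step()                   # identical parameter update on every rank
        return loss.item()                 # note: this is still the *local* loss value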

Hi! I need some advice. I have 4 processes/GPUs with DDP. Should I implement loss reduction by sum (using all_reduce) before the backward pass, or is it enough that the gradients are automatically averaged by DDP? Could increasing the learning rate by a factor of 4 compensate for the division by the number of GPUs done by the averaging? I am trying to get a DDP run equivalent to DataParallel.


It is not necessary to use another allreduce to sum the losses, and an additional allreduce might have a considerable negative impact on training speed.

Could increasing the learning rate by a factor of 4 compensate for the division by the number of GPUs done by the averaging?

This is not guaranteed, and the loss function itself also plays a role here. See this discussion: Should we split batch_size according to ngpu_per_node when DistributedDataparallel

I am trying to get a DDP run equivalent to DataParallel.

There is a subtle difference between DP and DDP. IIUC, with DP the gradients from the replicated models are accumulated (i.e., summed) into the param.grad field of the original model, whereas DDP’s gradients are averaged. I’m not 100% confident, but if you would like DDP to behave as similarly to DP as possible, you probably should multiply DDP’s resulting gradients by world_size. Whether that is the same as using a 4x learning rate may depend on the optimizer algorithm.
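
If you want to experiment with mimicking DP’s sum behavior, a rough sketch would rescale the averaged gradients right after backward and before the optimizer step (dp_like_step is a made-up helper; it assumes the default process group is initialized and ddp_model is the DDP-wrapped model):

    import torch
    import torch.distributed as dist

    def dp_like_step(ddp_model, loss, optimizer):
        loss.backward()                      # DDP has averaged the gradients across ranks
        with torch.no_grad():
            for p in ddp_model.parameters():
                if p.grad is not None:
                    p.grad.mul_(dist.get_world_size())   # undo the averaging -> DP-like sum
        optimizer.step()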


Thank you! This is extremely helpful!


I am working with the FCOS loss. The authors of FCOS handle the DDP case by reducing the loss components inside the loss script. So I should get rid of that part of their code and not use reduction before the backward pass. I will use the reduction only for plotting the loss values (after backward) in the training script.
Is that OK in your opinion? Thanks again!


Yep, this should be OK.
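
For reference, here is a minimal sketch of that logging-only reduction (reduce_loss_for_logging is a made-up helper and assumes the default process group is initialized); it operates on a detached copy after backward, so it has no effect on the gradients:

    import torch.distributed as dist

    def reduce_loss_for_logging(loss):
        # average the local loss across ranks purely for plotting/logging
        loss_for_log = loss.detach().clone()
        dist.all_reduce(loss_for_log, op=dist.ReduceOp.SUM)
        loss_for_log /= dist.get_world_size()
        return loss_for_log.item()           # same averaged value on every rank

Rank 0 can then collect the returned values and plot them.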


Thank you for the useful explanations.

From the discussion above, I understand that the reason one shouldn’t do an all_gather sum of the losses when training in DistributedDataParallel mode is that these all-gather operations can slow down the process.

Are there any other reasons why the loss tensors should not be summed other than performance reasons?

I ask this because, if the loss tensors are small and an all_gather sum is performed when computing the losses, the result is identical losses for all processes. Therefore the gradient averaging over processes will simply divide the losses by the number of processes.

This has the advantage of mimicking the behavior of DataParallel and of providing consistent results independently of the number of processes being run without the need to adjust learning rates, batch sizes, etc.

In short, when the cost of doing an all_gather sum of the losses is low, are there any other reasons beyond performance not to do it? And isn’t the consistent behavior independently of the number of processes an advantage?

Thank you

The reason this is not sufficient is that the gradient computation depends on both the loss and the activations, and the activations depend on the input data, which is different in each process. Therefore, even if the loss is communicated, you still need to communicate either the gradients or the activations to make sure the model parameters stay consistent across processes. Otherwise, if you only communicate the loss and then run backward locally, the models in different processes may diverge.
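
A single-process toy example illustrates this: two ranks can have identical loss values and still produce different gradients, because the gradients also depend on each rank’s own inputs and activations (the tensors below are made up for illustration):

    import torch

    w = torch.tensor([1.0, 1.0], requires_grad=True)

    # Two "ranks" with different inputs that happen to produce the same loss value.
    x_rank0 = torch.tensor([2.0, 0.0])
    x_rank1 = torch.tensor([0.0, 2.0])

    loss0 = (w * x_rank0).sum()             # 2.0
    loss1 = (w * x_rank1).sum()             # 2.0

    g0 = torch.autograd.grad(loss0, w)[0]   # tensor([2., 0.])
    g1 = torch.autograd.grad(loss1, w)[0]   # tensor([0., 2.])
    # Same loss value, different gradients -> gradient sync is still required.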

That makes sense. Thank you.

Hi @mrshenli, I hope you are well. Sorry, I need to show the training and validation loss on a graph. Can I just take the loss from gpu:0 and plot it, or do I need to use all_reduce for both the training and validation losses?

Hi @pietern, I hope you are well. Sorry, I need to plot the training and validation loss on a graph when using DDP with multiple GPUs. When I print the loss in the code, it shows me three different losses because I used 3 GPUs, which makes sense. Could you please guide me on how to all_reduce the loss for the graph? Or can I just take the loss from gpu_id=0 and plot that?

#!/usr/bin/env python
# coding: utf-8

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from transformers import TextDataset, DataCollatorForLanguageModeling
# from transformers import AutoModelWithLMHead
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import GPTNeoModel, GPTNeoForCausalLM

from tqdm import tqdm, trange
import numpy as np
import pandas as pd
import random
import gc
import math
import os
import sys
import time
import datetime
import copy
######################
weight_decay=0
learning_rate=5e-5
adam_epsilon=1e-8
warmup_steps = 1e2
lr=5e-5
Max_length=400

PathData='/home//NLP_Projects/CaseSummary_resolutionProject/Results_GPT_2/model_v4200_k_bs=16_lr=5e-05_epochs=20/'
pretrained_model = '/home//GPT_NEO_1.3B/'

########################################
def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

################################################
class GPT2Dataset(Dataset):

    def __init__(self, txt_list, tokenizer, gpt2_type=pretrained_model, max_length=400):

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []

        for txt in txt_list:

            encodings_dict = tokenizer('<|startoftext|>'+ txt + '<|endoftext|>', truncation=True, max_length=max_length, padding="max_length")

            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx] 

######################################################
def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ['CUDA_VISIBLE_DEVICES'] = "1,2,3"

    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
#########################################################

def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    
    gpu_id=rank
    

    ### defined variable ###############
    seed_val = 42
    random.seed(seed_val)
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    ###############################
    
    ddp_setup(rank, world_size)
    
    ###############################
    
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model, bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>') #gpt2-small

    # model_or = GPT2LMHeadModel.from_pretrained(pretrained_model)
    model_or = GPTNeoForCausalLM.from_pretrained(pretrained_model)

    model_or.resize_token_embeddings(len(tokenizer))

    ## load the train and test datasets
    print(PathData)
    trains_titles=pd.read_csv(PathData+'/'+'traindata.csv')
    valid_titles=pd.read_csv(PathData+'/'+'validdata.csv')
    
    trains_titles=trains_titles.drop(columns=['Unnamed: 0'])['0']
    valid_titles=valid_titles.drop(columns=['Unnamed: 0'])['0']

    print(trains_titles.head(2))
    
    train_dataset = GPT2Dataset(trains_titles, tokenizer, max_length=Max_length)

    Val_dataset = GPT2Dataset(valid_titles, tokenizer, max_length=Max_length)
    
    ############################################################################
    
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               pin_memory=True,
                                               shuffle=False,
                                               sampler=DistributedSampler(train_dataset))

    validation_loader = torch.utils.data.DataLoader(dataset=Val_dataset,
                                                    batch_size=batch_size,
                                                    pin_memory=True,
                                                    shuffle=False,
                                                    sampler=DistributedSampler(Val_dataset))
    
    
    total_steps = len(train_loader) * total_epochs


    ################# define optimizer and scheduler#########################

    # Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
    optimizer = AdamW(model_or.parameters(), lr = learning_rate,eps = adam_epsilon)


    # Create the learning rate scheduler.
    # This changes the learning rate as the training loop progresses
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps = warmup_steps, 
                                                num_training_steps = total_steps)

    ############################## train_loader and validation_loader ##################
 
    training_steps_per_epoch=len(train_loader)
    total_num_training_steps = int(training_steps_per_epoch*total_epochs)

    ######################## applying DDP on the model for training ######################
    # Wrap the original model directly. A deepcopy here would leave the optimizer
    # (built from model_or.parameters() above) pointing at parameters that are never
    # used in the forward/backward pass, so the weights would never be updated.
    model = model_or.to(gpu_id)
    model = DDP(model, device_ids=[gpu_id])
    print("gpu_id", gpu_id)
    # ========================================
    #               Training
    # ========================================


        
    training_stats = []
           

    for epoch_i in range(0, total_epochs):
        
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, total_epochs))
        print('Training...')

        ##########################################
        train_loader.sampler.set_epoch(epoch_i)
        b_sz = len(next(iter(train_loader))[0])
        print(f"[GPU{gpu_id}] Epoch {epoch_i} | Batchsize: {b_sz} | Steps: {len(train_loader)}")
        ##########################################

        t0 = time.time()

        total_train_loss = 0

        model.train()

        for step, batch in enumerate(train_loader):

            #################################
            b_input_ids = batch[0].to(gpu_id,non_blocking=True)
            b_labels = batch[0].to(gpu_id,non_blocking=True)
            b_masks = batch[1].to(gpu_id,non_blocking=True)
            #################################

            optimizer.zero_grad()        

            outputs = model(  b_input_ids,
                             labels=b_labels, 
                              attention_mask = b_masks,
                              token_type_ids=None
                            )

            loss = outputs[0]  
            batch_loss = loss.item()
            total_train_loss += batch_loss
        #    print("total_train_loss",total_train_loss)
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Calculate the average loss over all of the batches.
        avg_train_loss = total_train_loss / len(train_loader)  

        del total_train_loss
        del batch_loss
        
        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)
        print("  Training epoch took: {:}".format(training_time))

        # ========================================
        #               Validation
        # ========================================

        print("")
        print("Running Validation...")

        avg_val_loss_1=[]
        t0 = time.time()
        #################### is this section correct for validation? #############
        model.eval()
        # (the model is already wrapped in DDP above; wrapping it again here is not needed)
        ########################################
        total_eval_loss = 0
        nb_eval_steps = 0
        
        ########################################
        validation_loader.sampler.set_epoch(epoch_i)
        b_sz = len(next(iter(validation_loader))[0])
        print("bz",b_sz)
        print(f"[GPU{gpu_id}] Epoch {epoch_i} | Batchsize: {b_sz} | Steps: {len(validation_loader)}")
        ###########################################
        
        # Evaluate data for one epoch
        for batch in validation_loader:

            b_input_ids = batch[0].to(gpu_id,non_blocking=True)
            b_labels = batch[0].to(gpu_id,non_blocking=True)
            b_masks = batch[1].to(gpu_id,non_blocking=True)

            with torch.no_grad():        
                outputs  = model.module(b_input_ids,attention_mask = b_masks,labels=b_labels)
                loss = outputs[0]  
            batch_loss = loss.item()
          #  print("here batch loss",batch_loss)
            total_eval_loss += batch_loss        

        avg_val_loss = total_eval_loss / len(validation_loader)
       # print("here total_eval_loss=",total_eval_loss)

        avg_val_loss_1.append(avg_val_loss)

        validation_time = format_time(time.time() - t0)    

        del total_eval_loss 

   
        gc.collect()
        
        ################### saving the model ########################

        if gpu_id == 0:
    
            Path2=Results_Path+'/'+'savemodel_epoch=='+str(epoch_i)
    
            os.makedirs(Path2, exist_ok=True)

            ckp = model.module.state_dict()
            torch.save(ckp, Path2+"/checkpoint.pt")

        ############ save the results #####################
    pt_save_directory=Results_Path+'/'+'analyticsnumber'
    
    os.makedirs(pt_save_directory, exist_ok=True)
    
   # print("here",training_stats)
    Path_3=pt_save_directory+'/'+'training_stats='+str(42)+".csv"
    torch.save(training_stats,Path_3)
    

    destroy_process_group()
    ##############
if __name__ == '__main__':
    import sys
    total_epochs=int(sys.argv[1])
    save_every=int(sys.argv[2])
    batch_size=int(sys.argv[3])
    world_size = (torch.cuda.device_count())-1
    print(world_size)
    mp.spawn(main, args=(world_size, save_every, total_epochs, batch_size), nprocs=world_size,join=True)