Validation with DDP

Hey there, I’m trying to run both training and validation on the same node with 4 GPUs, using the nccl backend. However, I’m hitting a deadlock in the validation function: right after the last batch is processed, my code hangs in all_reduce and doesn’t continue from there. It’s worth noting that the code works fine with training on 4 GPUs and validation on just 1 GPU.
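In case it helps, here is roughly how I launch the processes (a simplified sketch; build_model and the trainer construction are placeholders for my actual code, and ddp_setup is the function shown at the end of this post):

import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

def main(rank: int, world_size: int):
    # one process per GPU; rank doubles as the local device index
    ddp_setup(rank, world_size)
    model = build_model().cuda(rank)          # build_model() is a placeholder
    model = DDP(model, device_ids=[rank])
    # build the trainer (class below) with the DDP-wrapped model and gpu_id=rank,
    # then call trainer.train()
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()    # 4 on this node
    mp.spawn(main, args=(world_size,), nprocs=world_size)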

#this is part of class TainerDDP:

 def train(self):
    max_epochs= self.config.N_EPOCHS
    last_epoch = 0
    if self.config.LOAD_CHECKPOINT:
        self.model.load_state_dict(torch.load(self.config.RESUME_CHECK))
        last_epoch = self.config.LAST_EPOCH
        self.best_model= torch.load(self.config.BEST_MODEL_PATH)

    best_loss = 1e6
    if self.gpu_id==0:
        print("training starts")
        wandb.watch(self.model,self.compute_loss, log="all", log_freq=1)


    self.model.train()
    for epoch in range(last_epoch,max_epochs):
        self.trainloader.sampler.set_epoch(epoch)
        #run epoch
        running_loss, running_score = AverageMeter(), AverageMeter()
        #batch loop
        if self.gpu_id == 0:
            pass  # update progress bar (omitted here)

        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
                # Send inputs and targets to GPU memory if device is CUDA
                inputs = inputs.cuda(self.gpu_id)
                targets = targets.cuda(self.gpu_id)

                # Runs the forward pass under autocast
                with torch.cuda.amp.autocast(enabled=self.config.USE_AMP):
                    # forward: predict output
                    y = self.predict(inputs, phase='train')

                    # Compute loss
                    loss = self.compute_loss(targets, y)

                # Runs the backward pass
                # Init gradient
                self.optimizer.zero_grad()
                # Backpropagation
                self.scaler.scale(loss).backward()
                # Update parameters
                self.scaler.step(self.optimizer)
                # Updates the scale for the next iteration.
                self.scaler.update()
                # Compute score
                score = get_score(targets,y)

                # Update evaluation metrics for each batch
                running_loss.update(loss.item(), inputs.size(0))
                running_score.update(score, inputs.size(0))

                if self.gpu_id == 0:
                    pass  # update and display progress bar (omitted here)

        # Update learning rate at the end of epoch loop (train)
        if self.config.SCHEDULER == 'OnPlateau':
            self.scheduler.step(running_loss.get_avg())
        else:
            self.scheduler.step()

         
        #VALIDATION
        dist.barrier()
        val_loss, val_score = self.validate(epoch)
        dist.barrier()

        # only save once on master gpu
        if self.gpu_id == 0:
            #SAVE BEST MODEL
            if val_loss < best_loss:
                print(f"New best model at epoch {epoch+1}")
                self.best_model = copy.deepcopy(self.model.state_dict())
                self._save_best_model(self.best_model)
                best_loss = val_loss
                
            # log metrics to wandb
            wandb.log({"train score": running_score.get_avg(), " train loss": 
            running_loss.get_avg(),"epoch":epoch})
            wandb.log({"val_score": val_score, "val_loss": val_loss, "epoch":epoch})
           
            #SAVE CHECKPOINT
            self._save_checkpoint(epoch)
            dist.barrier()
        #LOAD BEST MODEL to ALL GPUS    
        self.model.load_state_dict(self._load_best_model())
        dist.barrier()
    
    if self.gpu_id == 0:
        pass  # SAVE MODEL AS ONNX AND TERMINATE WANDB (omitted here)
    return self.model


def validate(self,epoch):
    # Record loss during validation for each epoch
    #run epoch
    running_loss, running_score = AverageMeter(), AverageMeter()
    if self.gpu_id == 0:
        pass  # SETUP PROGRESS BAR (omitted here)
    self.model.eval()
    #batch loop	
    for batch_idx, (inputs, targets) in enumerate(self.validloader):
                # Send inputs and targets to GPU memory if device is CUDA
                device = torch.device('cuda', self.gpu_id)
                inputs = inputs.to(device)
                targets = targets.to(device)

                # Runs the forward pass under autocast
                with torch.cuda.amp.autocast(enabled=self.config.USE_AMP):
                    # forward: predict output
                    y = self.predict(inputs, phase='valid')

                    # Compute loss
                    loss = self.compute_loss(targets, y)

                # Compute score and error

                score = get_score(targets,y)

                # Update evaluation metrics for each batch
                running_loss.update(loss.item(), inputs.size(0))
                running_score.update(score, inputs.size(0))
               
                if self.gpu_id == 0:
                    pass  # update progress bar (omitted here)

    print("validation done")
    
    # my program does not continue from here    <-----

    dist.barrier()         
    # Store loss and score for printing for each epoch
    try:
        running_loss.all_reduce()
    except Exception as e:
        print("Failed to perform all reduce for loss",e)
    try:
        running_score.all_reduce()
    except Exception as e:
        print("Failed to perform all reduce for score",e)

    print("performed all reduce")

    
    val_loss=running_loss.get_avg()
    score= running_score.get_avg()
    

    return val_loss, score

class AverageMeter(object):

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_avg(self):
        return self.avg

    def all_reduce(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count
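
For reference, the intent of all_reduce here is that every rank contributes its local sum and count and then all ranks hold the same global average. A minimal usage sketch (local_loss_value and local_batch_size are placeholders, and the process group is assumed to be initialized):

meter = AverageMeter()
meter.update(val=local_loss_value, n=local_batch_size)  # per-rank values
meter.all_reduce()              # collective: every rank in the group takes part
global_avg = meter.get_avg()    # identical on all ranks afterwards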

def ddp_setup(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"  # select any idle port on the machine
    init_process_group(backend="nccl", rank=rank, world_size=world_size)