Torch DDP Multi-GPU gives low accuracy metric

The torch.distributed docs (https://pytorch.org/docs/stable/distributed.html) have an exhaustive list of the available communication primitives. For sample code, I took the liberty of adapting your excerpt (the idea is identical for your test script).
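For context, the loop below assumes a standard DDP setup along these lines (the NCCL backend choice and the names world_size, train_dataset, and batch_size are my assumptions; rank is the per-process GPU index your script already uses):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# One process per GPU: "rank" identifies this process among "world_size" processes
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

# DistributedSampler shards the dataset so every rank trains on a disjoint subset
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

model = model.to(rank)
model = DDP(model, device_ids=[rank])

With that in place, here is the adapted epoch loop: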

for epoch in range(1, num_epochs + 1):
    # Reset the running metrics at the start of each epoch
    running_loss = 0.0
    running_corrects = 0.0

    model.train()

    # Re-seed the DistributedSampler so each epoch uses a different shuffle across ranks
    train_loader.sampler.set_epoch(epoch)

    # One epoch of training
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(rank)
        labels = labels.to(rank)

        # Get model predictions
        outputs = model(images)

        # Compute loss
        J = loss(outputs, labels)

        # Empty accumulated gradients
        optimizer.zero_grad()

        # Perform backprop
        J.backward()

        # Update parameters
        optimizer.step()

        '''                
        global step
        optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
        step += 1
        ''' 
            
        # Accumulate this rank's local statistics
        running_loss += J.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += torch.sum(predicted == labels.data).item()

    # To reduce the local metrics across ranks, they must be tensors on this rank's device
    running_loss = torch.tensor([running_loss], device=rank)
    running_corrects = torch.tensor([running_corrects], device=rank)

    if torch.distributed.is_initialized():
        torch.distributed.reduce(running_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.reduce(running_corrects, dst=0, op=torch.distributed.ReduceOp.SUM)

    # Log the aggregated metrics only on rank 0. Make sure "train_dataset" is the Dataset,
    # not the DataLoader, so that len() returns the size of the full dataset rather than the local shard.
    if rank == 0:
        train_loss = running_loss.item() / len(train_dataset)
        train_acc = (running_corrects.item() / len(train_dataset)) * 100
        print(f"GPU: {rank}, epoch = {epoch}; train loss = {train_loss:.4f} & train accuracy = {train_acc:.2f}%")