Torch DDP Multi-GPU gives low accuracy metric

I am trying Multi-GPU, single machine DDP training in PyTorch (CIFAR-10 + ResNet-18 setup). You can refer to the model architecture code here and the full training code here.

Within the main() function, the training loop is:

for epoch in range(1, num_epochs + 1):
        # Initialize metric for metric computation, for each epoch-
        running_loss = 0.0
        running_corrects = 0.0
    
        model.train()

        # Inform DistributedSampler about current epoch-
        train_loader.sampler.set_epoch(epoch)

        # One epoch of training-
        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(rank)
            labels = labels.to(rank)

            # Get model predictions-
            outputs = model(images)
              
            # Compute loss-
            J = loss(outputs, labels)
                
            # Empty accumulated gradients-
            optimizer.zero_grad()
                
            # Perform backprop-
            J.backward()
                
            # Update parameters-
            optimizer.step()

            '''                
            global step
            optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
            step += 1
            ''' 
            
            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels.data)
            
        train_loss = running_loss / len(train_dataset)
        train_acc = (running_corrects.double() / len(train_dataset)) * 100
        print(f"GPU: {rank}, epoch = {epoch}; train loss = {train_loss:.4f} & train accuracy = {train_acc:.2f}%")

The problem is that the train accuracy computed this way is very low (only about 7.44% on average across the 8 GPUs). But when I load the saved model and test its accuracy with the following code, I get a much higher value:

import torch
from tqdm import tqdm

# `device` and `loss` (the criterion) are globals defined elsewhere in the training script
def test_model_progress(model, test_loader, test_dataset):
    total = 0
    correct = 0
    running_loss_val = 0.0

    # Set model to evaluation mode once, before iterating-
    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, unit = 'batch') as tepoch:
            tepoch.set_description("Validation: ")
            for images, labels in tepoch:
                images = images.to(device)
                labels = labels.to(device)

                # Predict using trained model-
                outputs = model(images)
                _, y_pred = torch.max(outputs, 1)

                # Compute validation loss-
                J_val = loss(outputs, labels)

                running_loss_val += J_val.item() * labels.size(0)

                # Total number of labels-
                total += labels.size(0)

                # Total number of correct predictions (as a Python int)-
                correct += (y_pred == labels).sum().item()

                # Running averages over the batches seen so far-
                tepoch.set_postfix(
                    val_loss = running_loss_val / total,
                    val_acc = 100 * correct / total
                )

    # return (running_loss_val, correct, total)
    val_loss = running_loss_val / len(test_dataset)
    val_acc = (correct / total) * 100

    return val_loss, val_acc

test_loss, test_acc = test_model_progress(trained_model, test_loader, test_dataset)

print(f"ResNet-18 (multi-gpu DDP) test metrics; loss = {test_loss:.4f} & acc = {test_acc:.2f}%")
# ResNet-18 (multi-gpu DDP) test metrics; loss = 1.1924 & acc = 59.88%

Why is there this discrepancy? What am I missing?

If you train your model without DDP, do you see something similar?

What do you mean by something similar?

When training on a single GPU, the metrics don’t need to be aggregated. The more worrying point is that, with everything else unchanged, single-GPU training reaches a train accuracy of about 81%, while multi-GPU training is stuck under 60%.

Why is there such a huge gap in performance for multi-GPU training with DDP?

OK, it sounds like the training accuracy on a single GPU is not the same as with DDP. What about the test accuracy of single GPU vs. DDP? Is it the same?

They are not the same: after 50 epochs, the test accuracy is about 57.28% for multi-GPU training versus about 79.75% for single-GPU training. That is a gap of roughly 22 percentage points, which is huge!

I see, interesting. Have you tried reducing the per-rank batch size when using DDP? With DDP, each GPU receives different input while the model is replicated across all GPUs, and the gradients are averaged across ranks during the backward pass. If every rank keeps the original batch size, the effective global batch size is world_size times larger, and I am not sure whether that affects learning negatively.
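To make the batch-size point concrete: every rank draws its own batch, so one optimizer step effectively sees per-rank batch size × world_size samples. The sketch below is only the arithmetic, assuming a hypothetical single-GPU baseline of batch size 256 and LR 0.1; the LR scaling uses the linear-scaling heuristic (LR proportional to global batch), which is a common starting point rather than a guaranteed fix. The helper name `ddp_batch_settings` is hypothetical.

```python
def ddp_batch_settings(global_batch_size, world_size, base_lr, base_batch_size):
    """Hypothetical helper: split a global batch across ranks and scale the LR.

    Assumes the linear-scaling heuristic (LR proportional to the global batch
    size); this is a common starting point, not a rule that always holds.
    """
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by world_size")
    per_rank_batch = global_batch_size // world_size
    # Samples contributing to one optimizer step after DDP averages gradients.
    effective_batch = per_rank_batch * world_size
    scaled_lr = base_lr * effective_batch / base_batch_size
    return per_rank_batch, scaled_lr


# Option A: keep the single-GPU global batch (256) by shrinking the per-rank batch.
print(ddp_batch_settings(256, 8, base_lr=0.1, base_batch_size=256))
# Option B: keep 256 per rank (global batch 2048) and scale the LR instead.
print(ddp_batch_settings(2048, 8, base_lr=0.1, base_batch_size=256))
```

Option A reproduces the single-GPU optimization dynamics most closely; option B keeps per-GPU throughput high but changes the optimization problem, which is why the LR usually needs retuning.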

After reading through your training code, I have a guess.
You might want to check whether the way you set the lr in the optimizer works as expected, because the parameters' fully qualified names (FQNs) might change after wrapping the model with DDP. The rest looks good to me.
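A quick way to see what actually changes is to compare parameter names before and after wrapping. This is a single-process sketch that assumes a CPU-only gloo process group with world_size=1, a toy `nn.Linear` in place of ResNet-18, and a free local port (29531 is an arbitrary choice). In this setup DDP prefixes every name with `module.`, but the parameter tensors are the same objects. An optimizer built from `ddp_model.parameters()` therefore still holds those tensors, so the prefix mainly matters for name-based logic such as `state_dict` keys or param groups built by filtering on names.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process CPU group, only for illustration (assumption: port 29531 is free).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29531", rank=0, world_size=1
)

model = torch.nn.Linear(4, 2)  # toy stand-in for ResNet-18
ddp_model = DDP(model)

plain_names = [name for name, _ in model.named_parameters()]
ddp_names = [name for name, _ in ddp_model.named_parameters()]
print(plain_names)
print(ddp_names)

# The optimizer's param group holds the very same tensor objects as the inner model.
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
same_tensors = all(
    p is q for p, q in zip(optimizer.param_groups[0]["params"], model.parameters())
)
print(same_tensors)

dist.destroy_process_group()
```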

Can you elaborate on the optimizer part? Why will param_group’s FQN change due to DDP?

Your running_loss and dataloader are local to each GPU in the DDP setup. The DistributedSampler shards the dataset into equal portions across GPUs, so each rank only iterates over its local shard: len(dataloader.sampler) gives the number of samples in that shard (the full dataset size divided by the number of GPUs, rounded up), while len(dataloader) gives the number of batches in it. Likewise, running_loss is a local variable in each separate process (rank). To obtain numerically accurate results, you have to aggregate the running metrics on one rank (e.g. rank 0) with torch.distributed.reduce(tensor, dst=0, op=ReduceOp.SUM), or on every rank with torch.distributed.all_reduce.
Once you have done that, you can investigate whether you need to divide the batch size by the number of GPUs or adjust the learning rate to reproduce the numerical results of single-GPU training.
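The difference between the dataset, sampler, and loader lengths on a single rank can be checked without launching any processes, because DistributedSampler accepts explicit `num_replicas` and `rank` arguments. A sketch assuming a plain 50,000-element list standing in for the CIFAR-10 training set, 8 ranks, and a hypothetical per-rank batch size of 128:

```python
from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(50_000))  # stand-in for the 50,000 CIFAR-10 training images
world_size = 8

# Explicit num_replicas/rank, so no process group is required for this check.
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=128, sampler=sampler)

print(len(loader.dataset))  # full dataset, identical on every rank
print(len(loader.sampler))  # samples in this rank's shard
print(len(loader))          # batches in this rank's shard, not samples
```

Only `len(loader.sampler)` matches the number of samples a rank actually accumulates into its local metrics.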

TL;DR: I am fairly confident the models are nearly identical; you simply forgot to reduce (sum) the local metrics across the ranks before dividing by the full dataset length (len(dataloader.dataset), or len(dataloader.sampler) * world_size, which also counts any padding samples).
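The numbers in the question are consistent with this explanation. If every rank divides its local correct count by the full dataset length, each rank reports roughly the global accuracy divided by world_size: 7.44% × 8 = 59.52%, close to the 59.88% test accuracy of the saved model. A pure-Python sketch with made-up per-rank counts (hypothetical values, assuming 8 equal shards of the 50,000 training images):

```python
dataset_len = 50_000
world_size = 8
shard_len = dataset_len // world_size  # 6,250 samples per rank

# Hypothetical per-rank correct counts (roughly 60% of each shard).
local_corrects = [3720, 3750, 3705, 3740, 3730, 3725, 3760, 3710]

# What each rank prints in the question: local count / full dataset length.
per_rank_reported = [100 * c / dataset_len for c in local_corrects]

# Correct alternatives: divide by the local shard size, or sum across ranks first.
per_rank_local = [100 * c / shard_len for c in local_corrects]
global_acc = 100 * sum(local_corrects) / dataset_len

reported_avg = sum(per_rank_reported) / world_size
print(f"reported (avg):    {reported_avg:.2f}%")
print(f"local-shard (avg): {sum(per_rank_local) / world_size:.2f}%")
print(f"global:            {global_acc:.2f}%")
print(f"ratio:             {global_acc / reported_avg:.1f}")
```

With equal shards, averaging the per-shard accuracies and summing the counts give the same global value; only the question's formula is off by a factor of world_size.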


I don't know whether this is relevant:

https://torchmetrics.readthedocs.io/en/stable/pages/overview.html

Metrics in Distributed Data Parallel (DDP) mode

When using metrics in Distributed Data Parallel (DDP) mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is not equally divisible by batch_size * num_processors. The added samples will always be replicates of datapoints already in your dataset. This is done to secure an equal load for all processes. However, this has the consequence that the calculated metric value will be slightly biased towards those replicated samples, leading to a wrong result.

During training and/or validation this may not be important, however it is highly recommended when evaluating the test dataset to only run on a single gpu or use a join context in conjunction with DDP to prevent this behaviour.
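The padding is easy to observe with a tiny dataset. A sketch assuming 10 toy samples split across 3 ranks, with shuffling disabled so the output is deterministic. DistributedSampler rounds each shard up to ceil(10 / 3) = 4 samples, so 12 indices are drawn in total and two samples end up being seen twice.

```python
from torch.utils.data import DistributedSampler

dataset = list(range(10))  # 10 toy samples
world_size = 3

shards = []
for rank in range(world_size):
    # Explicit num_replicas/rank, so no process group is needed here.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    shards.append(list(sampler))

for rank, indices in enumerate(shards):
    print(f"rank {rank}: {indices}")

all_indices = [i for shard in shards for i in shard]
print("total drawn:", len(all_indices), "unique:", len(set(all_indices)))
```

For CIFAR-10 on 8 GPUs no padding occurs (50,000 and 10,000 are both divisible by 8), so this bias does not explain the large gap in the question.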


Is there any example code I can refer to?

Here is an exhaustive list of the available PyTorch communication primitives. For sample code, I took the liberty of adapting your excerpt (the idea is identical for your test script):

import torch
import torch.distributed as dist

for epoch in range(1, num_epochs + 1):
    # Initialize metric for metric computation, for each epoch-
    running_loss = 0.0
    running_corrects = 0

    model.train()

    train_loader.sampler.set_epoch(epoch)

    # One epoch of training-
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(rank)
        labels = labels.to(rank)

        # Get model predictions-
        outputs = model(images)

        # Compute loss-
        J = loss(outputs, labels)

        # Empty accumulated gradients-
        optimizer.zero_grad()

        # Perform backprop (DDP all-reduces and averages the gradients here)-
        J.backward()

        # Update parameters-
        optimizer.step()

        '''
        global step
        optimizer.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
        step += 1
        '''

        # Compute model's performance statistics on each rank (local shard only)
        running_loss += J.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += (predicted == labels).sum().item()

    # To globally reduce local metrics across ranks, they must be tensors on this rank's GPU
    running_loss = torch.tensor([running_loss], dtype=torch.float64, device=rank)
    running_corrects = torch.tensor([running_corrects], dtype=torch.float64, device=rank)

    # reduce() is a collective: every rank must call it; only dst=0 receives the sums
    dist.reduce(running_loss, dst=0, op=dist.ReduceOp.SUM)
    dist.reduce(running_corrects, dst=0, op=dist.ReduceOp.SUM)

    # Log the aggregated metrics only on rank 0. Make sure "train_dataset" is the Dataset, not the
    # DataLoader, so its length is the full dataset and not the local shard. (If the dataset size is
    # not divisible by world_size, DistributedSampler pads the shards, and
    # len(train_loader.sampler) * world_size is the exact number of samples seen.)
    if rank == 0:
        train_loss = running_loss.item() / len(train_dataset)
        train_acc = 100 * running_corrects.item() / len(train_dataset)
        print(f"GPU: {rank}, epoch = {epoch}; train loss = {train_loss:.4f} & train accuracy = {train_acc:.2f}%")
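If every rank (not just rank 0) needs the aggregated metrics, for example to drive early stopping or an LR schedule, `all_reduce` can replace `reduce`. Below is a sketch of a hypothetical helper, `reduce_epoch_metrics`, that also sums a per-rank sample counter (accumulated alongside running_loss), so any padding samples added by DistributedSampler are counted consistently in the denominator. When no process group is initialized, it simply uses the local values. With the NCCL backend, pass the rank's CUDA device as `device`.

```python
import torch
import torch.distributed as dist


def reduce_epoch_metrics(loss_sum, correct, count, device="cpu"):
    """Hypothetical helper: sum (loss_sum, correct, count) over all ranks.

    Must be called by every rank, because all_reduce is a collective op.
    Returns (mean loss, accuracy in %), identical on every rank.
    """
    stats = torch.tensor([loss_sum, correct, count], dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    loss_sum, correct, count = stats.tolist()
    return loss_sum / count, 100.0 * correct / count


# Without a process group, the helper just uses the local numbers.
mean_loss, acc = reduce_epoch_metrics(loss_sum=300.0, correct=450, count=500)
print(f"loss = {mean_loss:.4f}, acc = {acc:.2f}%")
```

Packing the three values into one tensor means a single collective per epoch instead of one per metric.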