How to correctly validate with DistributedDataParallel?

I followed the official tutorial and wrote a CIFAR-10 training script with DistributedDataParallel.
The code runs on one node with two GPUs. I split the dataset into two subsets by label: the subset containing labels [0, 1, ..., 4] runs on GPU 0, while the rest ([5, 6, ..., 9]) runs on GPU 1.

However, the validation results always show poor performance.

Here is my code:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Subset, DataLoader

from models.resnet import ResNet18


# set up process group
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12357'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.manual_seed_all(42)


def cleanup():
    dist.destroy_process_group()


def print_on_worker_0(worker, *args, **kwargs):
    if worker == 0:
        print(*args, **kwargs)


def train(model, criterion, optimizer, scheduler, train_loader, rank):
    train_loss = 0
    total = 0
    correct = 0
    model.train()
    with torch.enable_grad():
        for iteration, (inputs, labels) in enumerate(train_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = inputs.to(rank), labels.to(rank)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss
            train_loss += loss.item() * labels.size(0)
            predicted = outputs.argmax(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if iteration == (len(train_loader) - 1):
                mean_train_loss = train_loss / total
                accuracy = 100. * correct / total
                print(f"Rank: {rank}, label: {labels[-10:]}, predicted: {predicted[-10:]}")
                print(f'Rank: {rank}, training loss: {mean_train_loss:.4f}, accuracy: {accuracy:.4f}\n')
    scheduler.step()


def validate(model, criterion, valid_loader, rank):
    valid_loss = 0
    total = 0
    correct = 0
    model.eval()
    with torch.no_grad():
        for iteration, (inputs, labels) in enumerate(valid_loader):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = inputs.to(rank), labels.to(rank)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # loss
            valid_loss += loss.item() * labels.size(0)
            predicted = outputs.argmax(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if iteration == (len(valid_loader) - 1):
                mean_valid_loss = valid_loss / total
                accuracy = 100. * correct / total
                print(f"Rank: {rank}, label: {labels[-10:]}, predicted: {predicted[-10:]}")
                print(f'Rank: {rank}, validation loss: {mean_valid_loss:.4f}, accuracy: {accuracy:.4f}\n')


def main(rank, world_size):
    print(f"Running Exclusive Class DDP example on rank {rank}.")
    setup(rank, world_size)

    # data process
    batch_size = 32

    transform_train = T.Compose([
        T.RandomCrop(32, padding=4, padding_mode='reflect'),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_set = torchvision.datasets.CIFAR10(root='/datasets/CIFAR10', train=True,
                                             download=False, transform=transform_train)
    all_train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=10)

    labels = torch.tensor(train_set.targets)
    n_labels_per_worker = 10 // world_size
    indices_list = [torch.where(labels == (rank * n_labels_per_worker + i))[0]
                    for i in range(n_labels_per_worker)]
    subset_indices = torch.concat(indices_list)

    print(f"rank: {rank}, subset_len: {len(subset_indices)}, subset_indices: {subset_indices[10:20]}")

    train_subset = Subset(train_set, subset_indices)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False, num_workers=10)

    model = ResNet18().to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    for epoch in range(1):
        print_on_worker_0(rank, f'Epoch: {epoch}')
        train(model, criterion, optimizer, scheduler, train_loader, rank)
        validate(model, criterion, train_loader, rank)

    cleanup()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)

And the ResNet implementation is copied from models/resnet.py in the kuangliu/pytorch-cifar repository on GitHub (https://github.com/kuangliu/pytorch-cifar).

Note that the code is set up a bit unusually for debugging: the number of epochs is 1, shuffle is set to False on the train_loader, and the data loader for validation is the same as the one for training.

Output:

Running Exclusive Class DDP example on rank 1.
Running Exclusive Class DDP example on rank 0.
rank: 0, subset_len: 25000, subset_indices: tensor([179, 185, 189, 199, 213, 220, 223, 233, 264, 276])
rank: 1, subset_len: 25000, subset_indices: tensor([156, 157, 167, 173, 177, 182, 183, 195, 198, 215])
Epoch: 0
Rank: 1, label: tensor([9, 9, 9, 9, 9, 9, 9, 9], device='cuda:1'), predicted: tensor([9, 9, 9, 9, 9, 9, 9, 9], device='cuda:1')
Rank: 1, training loss: 0.7736, accuracy: 71.8760

Rank: 0, label: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0'), predicted: tensor([9, 9, 9, 9, 9, 9, 9, 9], device='cuda:0')
Rank: 0, training loss: 0.7691, accuracy: 71.3760

Rank: 1, label: tensor([9, 9, 9, 9, 9, 9, 9, 9], device='cuda:1'), predicted: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:1')
Rank: 1, validation loss: 6.8682, accuracy: 0.0000

Rank: 0, label: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0'), predicted: tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')
Rank: 0, validation loss: 5.7867, accuracy: 20.0000

I expected the validation accuracy to be close to the training accuracy, and the predictions to be close to the targets. However, the accuracy is at most 20%. It seems that the computation goes wrong somewhere.

I tried the extreme case where validation is exactly the same as training, and it worked. When I remove loss.backward() or optimizer.step(), the performance drops to 20%. It looks like a synchronization issue. I also checked the official example (main.py in the pytorch/examples repository on GitHub), but there is no synchronization between training and validation there.

Am I missing something, or some concept? Does anyone have any suggestions? Thanks!

Would you please try calling model.train() before training and model.eval() before validation, and see how it goes?

Thanks for your advice. As shown in the code snippet above, the results were already produced with that setting (model.train() called before training and model.eval() before validation).

The solution is to use SyncBatchNorm instead of BatchNorm in multi-GPU training. More precisely, convert the model with the convert_sync_batchnorm() method:
https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm
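
For example, a minimal sketch against the setup above; the only change is the extra conversion call before wrapping the model in DDP:

model = ResNet18().to(rank)
# Replace every BatchNorm layer with a SyncBatchNorm layer, so that batch
# statistics are computed across all GPUs instead of independently per GPU.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank], output_device=rank)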

The phenomenon is likely caused by the BatchNorm statistics being computed independently within each GPU; in this non-IID setting, each GPU's statistics differ greatly from the other's. (Note also that DDP with the default broadcast_buffers=True broadcasts module buffers, including the BatchNorm running statistics, from rank 0 to all ranks at each forward pass, so during validation rank 1 may effectively be evaluating with rank 0's statistics.)
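
Relatedly, when each rank only sees part of the validation data, the per-rank metrics are not the global metrics. Below is a minimal sketch of aggregating the counters from validate() across ranks with dist.all_reduce; the variable names follow the code above, and this is an illustrative addition rather than part of the original script:

# Pack the per-rank scalars into one tensor so a single all_reduce suffices.
stats = torch.tensor([valid_loss, correct, total], dtype=torch.float64, device=rank)
dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum the counters over all ranks
global_loss = stats[0].item() / stats[2].item()
global_accuracy = 100. * stats[1].item() / stats[2].item()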
