Subprocess stuck at loading batch

Hi all,

I was trying out a very simple example using DistributedDataParallel, but the code got stuck at data loading for some reason. The code I used is pasted below in its entirety.

import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import datasets, transforms
from torch import nn

class Model(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x

def main(rank, world_size):
    # Initialisation
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )
    # Fix random seed
    torch.manual_seed(0)
    # Initialize network
    net = Model()
    net.cuda(rank)
    # Initialize loss function
    criterion = torch.nn.CrossEntropyLoss().to(rank)
    optimizer = torch.optim.SGD(net.parameters(), 1e-4) 

    # Wrap the model with DistributedDataParallel
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])
    # Prepare dataset
    trainset = datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    )
    # Prepare sampler
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, num_replicas=world_size, rank=rank
    )
    # Prepare dataloader
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=100, shuffle=False,
        num_workers=0, pin_memory=True, sampler=train_sampler)
 
    epoch = 0
    iteration = 0

    for _ in range(5):
        epoch += 1
        train_loader.sampler.set_epoch(epoch)

        timestamp = time.time()
        print("Rank: {}. Before dataloader".format(rank))
        for batch in train_loader:
            print("Rank: {}. Batch loaded".format(rank))
            inputs = batch[0]
            targets = batch[1]

            iteration += 1
            inputs = inputs.cuda(rank, non_blocking=True)
            targets = targets.cuda(rank, non_blocking=True)

            output = net(inputs)
            loss = criterion(output, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

if __name__ == '__main__':

    # Number of GPUs to run the experiment with
    WORLD_SIZE = 2

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8888"
    mp.spawn(main, nprocs=WORLD_SIZE, args=(WORLD_SIZE,))

When I ran the program, the output looked like this:

Rank: 1. Before dataloader
Rank: 0. Before dataloader
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded
Rank: 0. Batch loaded

Looking at the GPU utilisation, I noticed that the first GPU (used by the subprocess with rank 0) is at full capacity (100%) while the second one sits at 0%. More interestingly, the subprocess with rank 1 also occupies a small amount of memory on the first GPU. I can't figure out what the problem is. Please let me know if you spot anything that might help.
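
In case it helps with debugging, here is a hypothetical helper (not part of the script above) that could be dropped into main to show which GPU each spawned process treats as its current device:

import torch

def report_current_device(rank):
    # Prints the index of the GPU this process currently treats as its
    # default device, alongside the rank it was spawned with.
    print("Rank {}: current device = cuda:{}".format(
        rank, torch.cuda.current_device()))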

Many thanks,
Fred

Thanks for reporting this issue! I can confirm that I reproduce it and have filed a bug at https://github.com/pytorch/pytorch/issues/46259 to get more discussion going.

Hi Rohan,

Thanks for taking a look. I've managed to resolve the hang by adding torch.cuda.set_device(rank) before the training loop. This stopped the subprocesses with rank greater than 0 from allocating memory on cuda:0, i.e. the device used by the rank-0 subprocess.
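
For anyone hitting the same thing, here is a minimal sketch of the change; the call just needs to appear before the training loop, and the rest of main stays exactly as in the original post.

def main(rank, world_size):
    # Initialisation
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )
    # The fix: make cuda:rank the default device for this process, so CUDA
    # work that does not name a device explicitly no longer lands on cuda:0.
    torch.cuda.set_device(rank)
    # ... the rest of main is unchanged ...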

Cheers,
Fred