Distributed training with DDP hangs

I am attempting to use DistributedDataParallel for single-node, multi-GPU training on a SageMaker Studio multi-GPU instance, inside a Docker container. My entry code is as follows:

import os
from PIL import ImageFile
import torch.multiprocessing as mp

nodes, gpus = 1, 4
nr = 0  # node rank; always 0 on a single node
world_size = nodes * gpus

# set environment variables for distributed training
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# workaround for an issue with the data
ImageFile.LOAD_TRUNCATED_IMAGES = True

# a PyTorch Dataset object which loads pairs of images for contrastive learning
# I have tried both pre-loading and lazy-loading of images, same results either way
train_ds = get_training_dataset()

# this is an nn.Module which is not (yet) wrapped in DistributedDataParallel
# model is relatively complex but does not seem to be related to the issue
# so I have cut excess code here. if necessary I can provide some more detail
model = get_model()

# "fork" as I read it works better in notebooks; "spawn" was giving an error
# "gloo" over "nccl" as the backend because NCCL was giving an error
# "veth-app0-2" as interface name was obtained from running 'ifconfig'
mp.start_processes(
    training_worker_func,
    nprocs=gpus,
    args=(nr, gpus, world_size, train_ds, model, epochs, lr, momentum,
          weight_decay, batch_size, "gloo", "veth-app0-2"),
    join=True,
    start_method="fork"
)

The training function is then:

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def training_worker_func(gpu, nr, gpus, world_size, train_ds, model, epochs, lr, momentum, 
                         weight_decay, batch_size, distributed_backend, network_ifname):
    # compute overall rank from node rank, GPUs, and current GPU
    rank = nr * gpus + gpu

    # set network interface environment variable
    # without this, a warning shows about being unable to detect interface and using a fallback
    # but the results are the same other than the warning
    os.environ[distributed_backend.upper() + "_SOCKET_IFNAME"] = network_ifname

    # initialise the process group for distributed processing
    dist.init_process_group(
        backend=distributed_backend,
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )

    # set CUDA to the correct GPU
    torch.cuda.set_device(gpu)

    # switch model to CUDA
    # I have also tried deep-copying the model before calling cuda(), same result
    model.cuda(gpu)

    # convert model to DistributedDataParallel
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # initialise loss function
    criterion = nn.CrossEntropyLoss().cuda(gpu)

    # initialise SGD as the optimiser
    opt = SGD(model.parameters(), lr, momentum=momentum, weight_decay=weight_decay)

    # use a distributed sampler for training
    sampler = DistributedSampler(train_ds, num_replicas=world_size, rank=rank)

    # create the training DataLoader using the sampler
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=False, num_workers=0,
                          pin_memory=True, sampler=sampler)

    # begin the training process
    for epoch in range(epochs):
        loss_weights = [0.1, 0.4, 0.7, 1.0]
        sum_total_loss = 0.0

        # switch to train mode
        model.train()

        print(f"GPU {gpu} beginning epoch {epoch} of {epochs}...")  # this prints for all 4 processes
        for j, images in enumerate(train_dl):
            print(f"GPU {gpu}, epoch {epoch}: {j} / {len(train_dl)}")  # this never prints for any process

            images[0] = images[0].cuda(non_blocking=True)
            images[1] = images[1].cuda(non_blocking=True)

            # compute output
            output, target = model(im_q=images[0], im_k=images[1])

            # compute loss
            losses = [criterion(output[:,i,:], target[:,i]) for i in range(12)]
            total_loss = sum(l * loss_weights[i%4] for i, l in enumerate(losses))
            sum_total_loss += total_loss

            # compute gradient and do SGD step
            opt.zero_grad()
            total_loss.backward()
            opt.step()

        if gpu == 0:
            # only print loss if this is the main worker
            print(f"Epoch {epoch + 1} training loss: {sum_total_loss / len(train_dl)}")

    if gpu == 0:
        # only print if this is the main worker
        print("Training complete!")

The result of this code is that the 4 processes are spawned without error and the “GPU … beginning epoch 0 of 1…” message prints for all 4 of them, but then none of the processes ever reach the next print statement. There is no error; the processes just hang.

The images all seem to be loaded into the instance's memory successfully, as a good chunk of it is in use. I don’t think the problem is data-specific, as I can iterate through the same data locally without any problem.

So it seems like there is some issue in the distributed processes where the data loaded onto the CPU is not being shared with them? I am new to distributed training, so I am unsure exactly how this works internally. Any help or guidance would be much appreciated!

The code generally looks good, but I have concerns about the fork argument in start_processes: the CUDA runtime does not support forked subprocesses. I would try replacing it with spawn, for example:

mp.spawn(
    training_worker_func,
    nprocs=gpus,
    args=(nr, gpus, world_size, train_ds, model, epochs, lr, momentum,
          weight_decay, batch_size, "gloo", "veth-app0-2"),
    join=True,
)

As an aside: I have tried using Python multiprocessing in Jupyter notebooks and there were many caveats and gotchas, so I would recommend running this as a plain Python script first and validating that it works that way. You mentioned you are using a notebook and SageMaker, so I’m not sure if there are similar issues there.
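For example, a minimal standalone launcher script could look something like this (the module name and the hyperparameter values here are placeholders; get_training_dataset, get_model and training_worker_func would come from your own code):

import os
import torch.multiprocessing as mp

# placeholder module name; import these from wherever they live in your project
from my_training_code import get_training_dataset, get_model, training_worker_func

def main():
    nodes, gpus, nr = 1, 4, 0  # single node, so the node rank is 0
    world_size = nodes * gpus

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    # placeholder hyperparameters; substitute your own values
    epochs, lr, momentum, weight_decay, batch_size = 1, 0.01, 0.9, 1e-4, 32

    train_ds = get_training_dataset()
    model = get_model()

    mp.spawn(
        training_worker_func,
        nprocs=gpus,
        args=(nr, gpus, world_size, train_ds, model, epochs, lr, momentum,
              weight_decay, batch_size, "gloo", "veth-app0-2"),
        join=True,
    )

# the __main__ guard is required when child processes are started with spawn
if __name__ == "__main__":
    main()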

Thanks for the input! I have just tried running the code again, with spawn instead of fork. The result is an error which reads:

Bus error (core dumped)

This error pops up at the same point where the process was hanging when using fork: after initialising the process group, when attempting to load data.

There is a core.xxx file as well, which is 14GB. I have read that this bus error is usually a result of insufficient shared memory, but the Docker container my instance is (or at least, should be) running inside was built with 16GB of shared memory specified, which I would presume to be sufficient.
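One way to double-check how much shared memory the container actually has is to inspect /dev/shm from inside it (assuming shared memory is mounted there, which is the Docker default):

import shutil

# /dev/shm backs the shared memory used when tensors are passed between
# processes; "total" should reflect the shared memory size the container was
# started with
total, used, free = shutil.disk_usage("/dev/shm")
print(f"/dev/shm: total {total / 1e9:.1f} GB, free {free / 1e9:.1f} GB")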

As an aside, my reason for trying fork originally was due to a comment in the PyTorch code which says:

in environments like Ipython notebooks, ‘fork’ works better than ‘spawn’

But I am not sure if this is relevant. You may well be right that the problem is related to the notebook environment - even if it is not the problem directly, it is definitely making it quite hard to debug issues so I might consider alternatives.

cc @VitalyFedyunin @ejguan for any guidance on using data loaders with multiprocessing and notebooks. It seems to be hanging when retrieving a batch from the data loader.

If this is running on Windows, I don’t know of a way to run multiprocessing in a notebook on Windows.

You are running multiprocessing here, which can potentially push the memory usage up to 14GB * the number of processes within the DataLoader. Since we don’t have the code for your dataset, I can only give you a rough idea of how to reduce the memory usage:

  1. Try to load data lazily. If you are using a Dataset class, try to put the code that loads the images inside the __getitem__ method.
  2. Try to use an IterableDataset, or use torchdata to create your dataset. By nature, this type of class uses a streaming-style data loading method (see the sketch below).
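For option 2, a rough sketch would be something along these lines (load_rgb here just stands in for whatever helper you use to decode an image):

from torch.utils.data import IterableDataset

class StreamingImageDataset(IterableDataset):
    """Yields images one at a time instead of holding them all in memory."""

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths  # only the paths live in memory
        self.transform = transform

    def __iter__(self):
        for path in self.image_paths:
            image = load_rgb(path)  # each image is decoded on demand
            if self.transform is not None:
                image = self.transform(image)
            yield image

One caveat: an IterableDataset is responsible for its own sharding and cannot be combined with a DistributedSampler, so if you want to keep the sampler, option 1 (lazy loading inside __getitem__ of a map-style Dataset) is the smaller change.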

Thanks for the tips. I did originally use lazy loading, but I have tweaked the code since then, so I will revert to lazy loading and see if it makes a difference. I will also have a look into the IterableDataset and torchdata options.

This is on a SageMaker instance so running on Linux.

The loaded data can be relatively large as this is a computer vision problem, but the entire dataset is still only 1.63GB and the model is < 100MB.

The Dataset code is quite simple:

from typing import Callable, Optional

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# open_rgb is a small helper elsewhere in my code that opens an image file as RGB

class ImageDataset(Dataset):  # class name here is illustrative
    def __init__(
        self,
        images,
        transform: Optional[Callable] = None,
        preload: bool = False,
    ):
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])

        # either pre-load every image into memory, or keep only the file paths
        self.images = [open_rgb(i) for i in images] if preload else images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        path = self.images[i]
        # lazy-load from disk unless the image was pre-loaded in __init__
        image = path if isinstance(path, Image.Image) else open_rgb(path)

        if self.transform:
            image = self.transform(image)
        return image

For this use case I then pass a transform which applies two random augmentations to the image and returns a pair of augmented versions in place of the single image, roughly as sketched below.
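The pairing wrapper is along these lines (simplified; the real augmentation pipeline in my code is heavier than the example composition shown here):

from torchvision import transforms

class TwoCropsTransform:
    """Applies the same random augmentation pipeline twice to produce a pair."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        # two independent random augmentations of the same source image
        q = self.base_transform(image)
        k = self.base_transform(image)
        return [q, k]

# the kind of pipeline passed in as base_transform (details differ in my code)
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
pair_transform = TwoCropsTransform(augment)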

Is self.images a list of strings when preload is False?

Yes, self.images is a list of strings representing paths to image files in that case.

Since my last update I have moved from the SageMaker Studio notebook to an EC2 instance and successfully trained an epoch of the model with the same code (preload now set to False, with spawn instead of fork, and using nccl instead of gloo), so it appears the notebook environment was at least part of the problem.

Thanks Howard and Erjia for the help! I have one last small issue: the training will now not finish. If I set the number of epochs to 1, then going back to the training loop code, I see all of the “GPU {n} epoch {e} / {E}” messages print, but the final message with the epoch training loss is never printed.

Making it weirder, if I set the number of epochs to 2, I see the training loss printed at the end of the first epoch, then all of the batches being processed for the second epoch, but never the training loss for the second epoch.

I added additional status updates and found that the opt.step() call completes successfully at the end of the final epoch on all 4 GPUs, but at the end of the loop the processes seem to hang, despite completing fine for the prior epoch.

For distributed training, I believe you need to have the same number of batches fed into your model in each distributed process.

You should make sure your DistributedSampler is able to shard your dataset evenly, to prevent this hanging issue.
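A quick sanity check is to print the per-rank counts inside the worker right after building the sampler and the DataLoader, something like:

# inside training_worker_func, right after creating sampler and train_dl
print(f"rank {rank}: {len(sampler)} samples, {len(train_dl)} batches")

# if the batch count differs across ranks, the rank with extra batches will
# block in DDP's gradient all-reduce while the others have already left the
# loop; drop_last=True on the DataLoader (or on the sampler) is one way to
# keep the counts identical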

Thanks. I validated that the dataset is being sharded evenly, but I still had the same problem.

As this only occurs at the end of the final epoch no matter how many epochs I train for, my current solution is just to train for an additional epoch, and save a checkpoint prior to the final epoch. This checkpoint then serves as the final model.
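Concretely, the checkpointing is just a rank-0 save at the end of each epoch, roughly:

# at the end of each epoch in training_worker_func, on the main worker only
if gpu == 0:
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),  # unwrap the DDP wrapper
            "optimizer_state_dict": opt.state_dict(),
        },
        f"checkpoint_epoch_{epoch}.pth",
    )

and the checkpoint saved before the final (sacrificial) epoch is the one I keep as the final model.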

I will continue to see if I can resolve the underlying issue, but I now have a working process. Thanks very much for all the help along the way!