Properly implementing DDP in training loop with cleanup, barrier, and its expected output

Hi,

I’m currently trying to figure out how to properly implement DDP with cleanup, barrier, and its expected output. While I think the DDP tutorial (Getting Started with Distributed Data Parallel, PyTorch Tutorials 1.11.0+cu102 documentation) gives a great initial example of how to do this, I’m having some trouble translating that example into something more illustrative. I’ve chosen to translate the PyTorch CIFAR example (Training a Classifier, PyTorch Tutorials 1.11.0+cu102 documentation) into the form of the above-mentioned DDP tutorial.

For context: I’m currently running on a single node with 4 GPUs:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M|
| Fan  Temp  Perf  Pwr:Usage/Cap|
|                               |
|===============================+
|   0  Tesla V100S-PCI...  Off  |
| N/A   35C    P0    37W / 250W |
|                               |
+-------------------------------+
|   1  Tesla V100S-PCI...  Off  |
| N/A   33C    P0    36W / 250W |
|                               |
+-------------------------------+
|   2  Tesla V100S-PCI...  Off  |
| N/A   32C    P0    35W / 250W |
|                               |
+-------------------------------+
|   3  Tesla V100S-PCI...  Off  |
| N/A   32C    P0    37W / 250W |
|                               |
+-------------------------------+

The code I’m implementing can be found below (with many of the comments from the original CIFAR example removed for brevity of this post):

import os
import sys
import tempfile
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '12355'

  dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
  dist.destroy_process_group()

def run_demo(demo_fn, world_size):
  mp.spawn(demo_fn,
    args=(world_size,),
    nprocs=world_size,
    join=True)

def imshow(img):
  img = img / 2 + 0.5     # unnormalize
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()

class Net(nn.Module):
    def __init__(self):
      super().__init__()
      self.conv1 = nn.Conv2d(3, 6, 5)
      self.pool = nn.MaxPool2d(2, 2)
      self.conv2 = nn.Conv2d(6, 16, 5)
      self.fc1 = nn.Linear(16 * 5 * 5, 120)
      self.fc2 = nn.Linear(120, 84)
      self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = torch.flatten(x, 1)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = self.fc3(x)
      return x

def generate_data():

  transform = transforms.Compose(
      [transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  batch_size = 4

  trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)

  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True)#, num_workers=2)

  testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)

  testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False)#, num_workers=2)

  classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  
  #dataiter = iter(trainloader)
  #images, labels = dataiter.next()

  #imshow(torchvision.utils.make_grid(images))
  # print labels
  #print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
  return trainset, trainloader, testset, testloader, classes

def demo_basic(rank, world_size):

  print(f"Running basic DDP example on rank {rank}.")
  setup(rank, world_size)

  trainset, trainloader, testset, testloader, classes = generate_data()

  net = Net().to(rank)

  ddp_model = DDP(net, device_ids=[rank])

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):

      inputs, labels = data

      optimizer.zero_grad()

      labels = labels.to(rank)

      outputs = ddp_model(inputs.to(rank))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

      # Should cleanup be here or below?
      #cleanup()
      
      if i % 2000 == 1999:
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
        running_loss = 0.0
        break
        
  # Should cleanup be here or above?
  cleanup()

  print('Finished Training')

def demo_checkpoint(rank, world_size):

  print(f"Running checkpoint DDP example on rank {rank}.")
  setup(rank, world_size)

  trainset, trainloader, testset, testloader, classes = generate_data()

  net = Net().to(rank)

  ddp_model = DDP(net, device_ids=[rank])

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
  if rank == 0:
    # All processes should see same parameters as they all start from same
    # random parameters and gradients are synchronized in backward passes.
    # Therefore, saving it in one process is sufficient.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

  for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):

      inputs, labels = data

      optimizer.zero_grad()

      labels = labels.to(rank)

      outputs = ddp_model(inputs.to(rank))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

      # Should cleanup be here or below?
      #cleanup()
      
      if i % 2000 == 1999:
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
        running_loss = 0.0
        break
  
    # Should the model be saved before or after cleanup?
    #
    #Saving the model after each epoch
    if rank == 0:
      # All processes should see same parameters as they all start from same
      # random parameters and gradients are synchronized in backward passes.
      # Therefore, saving it in one process is sufficient.
      torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

      # Should cleanup be here or above?
      #cleanup()
  
  # Now that we are done with training, do we now cleanup the processes again?
  if rank == 0:
    os.remove(CHECKPOINT_PATH)

  cleanup()

  print('Finished Training')

  #PATH = './cifar_net.pth'
  #torch.save(net.state_dict(), PATH)

  dataiter = iter(testloader)
  images, labels = next(dataiter)

  #imshow(torchvision.utils.make_grid(images))
  #print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

  # Reloading the model to the rank and to DDP since we have already
  # executed "cleanup" on the model
  # create local model
  net = Net().to(rank)

  ddp_model = DDP(net, device_ids=[rank])

  # Loading our model since we have redefined our 
  # configure map_location properly
  map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
  ddp_model.load_state_dict(
  torch.load(CHECKPOINT_PATH, map_location=map_location))

  outputs = ddp_model(images.to(rank))

  _, predicted = torch.max(outputs, 1)

  print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                                for j in range(4)))

  correct = 0
  total = 0

  with torch.no_grad():
    for data in testloader:
      images, labels = data

      labels = labels.to(rank)

      outputs = ddp_model(images.to(rank))

      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

  correct_pred = {classname: 0 for classname in classes}
  total_pred = {classname: 0 for classname in classes}

  with torch.no_grad():

    for data in testloader:
      images, labels = data

      labels = labels.to(rank)

      outputs = ddp_model(images.to(rank))
      _, predictions = torch.max(outputs, 1)

      for label, prediction in zip(labels, predictions):

        labels = labels.to(rank)
        if label == prediction:
          correct_pred[classes[label]] += 1
        total_pred[classes[label]] += 1

  for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

class NetMP(nn.Module):
    def __init__(self, dev0, dev1, dev2, dev3):
      super(NetMP, self).__init__()
      #super().__init__()
      self.dev0 = dev0
      self.dev1 = dev1
      self.dev2 = dev2
      self.dev3 = dev3
      self.conv1 = nn.Conv2d(3, 6, 5).to(dev0)
      self.pool = nn.MaxPool2d(2, 2).to(dev1)
      self.conv2 = nn.Conv2d(6, 16, 5).to(dev2)
      self.fc1 = nn.Linear(16 * 5 * 5, 120).to(dev3)
      self.fc2 = nn.Linear(120, 84).to(dev1)
      self.fc3 = nn.Linear(84, 10).to(dev2)

    def forward(self, x):
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = torch.flatten(x, 1) # flatten all dimensions except batch
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = self.fc3(x)
      return x

def demo_model_parallel(rank, world_size):
  # For the sake of reducing redundancy, let's implement demo_model_parallel as a
  # version of demo_basic, not demo_checkpoint
  print(f"Running DDP with model parallel example on rank {rank}.")
  setup(rank, world_size)

  dev0 = ((rank * 4) % 4)
  dev1 = ((rank * 4 + 1) % 4)
  dev2 = ((rank * 4 + 2) % 4)
  dev3 = ((rank * 4 + 3) % 4)

  trainset, trainloader, testset, testloader, classes = generate_data()

  net = NetMP(dev0, dev1, dev2, dev3)

  ddp_model = DDP(net)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(trainloader):

      inputs, labels = data

      optimizer.zero_grad()

      labels = labels.to(dev1)

      outputs = ddp_model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

      # Should cleanup be here or below?
      #cleanup()
      
      if i % 2000 == 1999:
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
        running_loss = 0.0
        break
        
      # Should cleanup be here or above?
      cleanup()

  print('Finished Training')

if __name__=="__main__":
  n_gpus = torch.cuda.device_count()
  assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
  world_size = n_gpus
  run_demo(demo_basic, world_size)
  run_demo(demo_checkpoint, world_size)
  run_demo(demo_model_parallel, world_size)

To simplify our work, let’s first make the small adjustment of:

  run_demo(demo_basic, world_size)
  #run_demo(demo_checkpoint, world_size)
  #run_demo(demo_model_parallel, world_size)

Keeping an eye on our GPUs, we see:

+-------------------------------------------------------------------------+
| Processes:                                                              |
|  GPU   GI   CI        PID   Type   Process name              GPU Memory |
|        ID   ID                                               Usage      |
|=========================================================================|
|    0   N/A  N/A   0000001	 C                                     307MiB |
|    0   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python      917MiB |
|    0   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python      905MiB |
|    0   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python      905MiB |
|    0   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python      905MiB |
|    1   N/A  N/A   0000001	 C                                     307MiB |
|    1   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python     1403MiB |
|    2   N/A  N/A   0000001	 C                                     307MiB |
|    2   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python     1403MiB |
|    3   N/A  N/A   0000001	 C                                     307MiB |
|    3   N/A  N/A   0000001	 C   ...onda3/envs/env/bin/python     1403MiB |
+-------------------------------------------------------------------------+

and all values in the Volatile GPU-Util column are below 100%. Our terminal output looks like:

Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
Running basic DDP example on rank 2.
Running basic DDP example on rank 3.
[1,  2000] loss: 2.181
[1,  2000] loss: 2.187
[1,  2000] loss: 2.180
[1,  2000] loss: 2.187
[2,  2000] loss: 1.740
[2,  2000] loss: 1.737
[2,  2000] loss: 1.739
Finished Training
[2,  2000] loss: 1.737
Finished Training
Finished Training
Finished Training

Here is my first issue: if things are being distributed and run in parallel, shouldn’t we have a single loss per epoch rather than 4 losses per epoch? From the output, it appears that rather than rank 0 sending the forward and backward work to the other ranks, collecting the results, and delivering a single result to the user (us), each rank runs its own forward and backward pass and reports its own result to the user. In effect, while the work is parallel, the ranks are not working together but independently. This behavior would in fact not result in a speedup; the run would only be as fast as the slowest rank.

Is my interpretation of what’s going on correct, or am I missing something? Also, is this the expected behavior when using DDP in this way, or is my code incorrect? Any and all thoughts are welcome.

Let’s go ahead and run the same code with the small change of:

  #run_demo(demo_basic, world_size)
  run_demo(demo_checkpoint, world_size)
  #run_demo(demo_model_parallel, world_size)

This drives the Volatile GPU-Util on all of our GPUs to 100%. Our terminal output now stalls at:

Running checkpoint DDP example on rank 3.
Running checkpoint DDP example on rank 1.
Running checkpoint DDP example on rank 0.
Running checkpoint DDP example on rank 2.

At this point I have to stop the code manually because it stalls completely, yet no error appears. After a little debugging, I find that rank 0 never gets past the line

dist.barrier()

I’m not sure how to debug or fix this, especially given that it’s part of the original DDP example (and its function and placement make sense)… Any thoughts?

Let’s go ahead and run the same code with the small change of:

  #run_demo(demo_basic, world_size)
  #run_demo(demo_checkpoint, world_size)
  run_demo(demo_model_parallel, world_size)

after which I am greeted with this error:

Traceback (most recent call last):
  File "/user/classical_parallel_compute.py", line 425, in <module>
    run_demo(demo_model_parallel, world_size)
  File "/user/classical_parallel_compute.py", line 53, in run_demo
    mp.spawn(demo_fn,
  File "/user/miniconda3/envs/env/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/user/miniconda3/envs/env/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/user/miniconda3/envs/env/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/user/miniconda3/envs/env/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/user/classical_parallel_compute.py", line 378, in demo_model_parallel
    ddp_model = DDP(net)#, device_ids=[rank])
  File "/user/miniconda3/envs/env/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 578, in __init__
    dist._verify_model_across_ranks(self.process_group, parameters)
RuntimeError: NCCL error in: /user/conda/feedstock_root/build_artifacts/pytorch-recipe_1640869844479/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:957, invalid usage, NCCL version 21.1.4
ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).

Running with NCCL debug logging enabled:

NCCL_DEBUG=INFO python classical_parallel_compute.py

produced output that looked something like:

NCCL WARN Duplicate GPU detected : rank 0 and rank 1 both on CUDA device 3...
NCCL WARN Duplicate GPU detected : rank 3 and rank 0 both on CUDA device 3....

Another post solves this problem via:

net.to(f'cuda:{args.local_rank}')

However, you’ll notice that in the model parallelism example of the DDP tutorial, the model itself was sent to a rank rather than to a CUDA device, so I am hesitant to say that the above solution applies.

I’m not sure what’s going wrong or what needs fixing. I suspect that I’m using cleanup() incorrectly in my attempt to accumulate the behavior of the GPUs on my system, though that may not be the problem at all…

Any suggestions to code changes or advice as to what to fix would be very much appreciated.

Regarding speedup, you are correct that the bottleneck would be the slowest rank. However, DDP aims to speed up training by distributing the original dataset across workers. For example, if you have to process 1k inputs for your training, DDP with 4 workers will instead have each GPU process only 250 inputs, which ideally leads to a linear 4x speedup. In practice, the speedup is not linear due to communication overhead.
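To make the 1k-inputs example concrete, here is a minimal sketch of how DistributedSampler splits the indices, assuming a dummy 1,000-element TensorDataset and 4 workers (none of this comes from your script). Passing num_replicas and rank explicitly lets you inspect the split without initializing a process group:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(1000))  # 1k dummy inputs
world_size = 4

# One sampler per rank; in a real DDP job each process builds only its own.
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False))
    for rank in range(world_size)
]

for rank, shard in enumerate(shards):
    print(f"rank {rank}: {len(shard)} inputs, first indices {shard[:3]}")
# rank 0: 250 inputs, first indices [0, 4, 8]
# rank 1: 250 inputs, first indices [1, 5, 9]
# ...
```

Each process iterates only over its own 250 indices; the gradient all-reduce after each backward pass is the communication that keeps the real speedup below 4x.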

Regarding the other two issues, thank you for providing a reproduction script! I’ve filed a GitHub issue with this repro: Investigate and update DDP tutorials · Issue #74246 · pytorch/pytorch · GitHub, mentioning that we need to investigate and fix our tutorials accordingly.

With regard to your final comment, it makes sense that net.to(f'cuda:{args.local_rank}') would fix the issue, as DDP assumes that every process uses a distinct GPU. The issue also tracks updating our tutorials to ensure that we do this.
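As a sketch of why the duplicate-GPU warning appears, the snippet below compares the modulo device assignment used for dev0..dev3 in demo_model_parallel with a layout where each process owns its own block of GPUs. The helper names modulo_devices and devices_for_rank are hypothetical, and the disjoint layout assumes world_size = n_gpus // gpus_per_rank:

```python
from typing import List


def modulo_devices(rank: int, gpus_per_rank: int, n_gpus: int) -> List[int]:
    # Same formula as dev0..dev3 in demo_model_parallel: (rank * 4 + k) % 4.
    return [(rank * gpus_per_rank + k) % n_gpus for k in range(gpus_per_rank)]


def devices_for_rank(rank: int, gpus_per_rank: int) -> List[int]:
    # Hypothetical layout: rank r exclusively owns GPUs [r * g, r * g + g).
    start = rank * gpus_per_rank
    return list(range(start, start + gpus_per_rank))


n_gpus = 4

# Posted setup: world_size = n_gpus = 4 ranks, each NetMP spread over 4 devices.
shared = [modulo_devices(rank, 4, n_gpus) for rank in range(n_gpus)]
print(shared)  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

# Two devices per model replica: world_size = n_gpus // 2 = 2 ranks, no overlap.
gpus_per_rank = 2
world_size = n_gpus // gpus_per_rank
disjoint = [devices_for_rank(rank, gpus_per_rank) for rank in range(world_size)]
print(disjoint)  # [[0, 1], [2, 3]]
```

With 4 GPUs, a model split across 4 devices leaves room for only one DDP process; every rank in the posted setup lands on the same four devices, which matches the "Duplicate GPU detected" warnings.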

@rvarm1 - thanks for opening up the GitHub issue on this!

Before I forget, the net.to(f'cuda:{args.local_rank}') solution can be found here, in case the context of that post becomes relevant.

Regarding your response to my first paragraph (“Regarding speedup…not linear due to the communication overhead”), I’m still a little confused. I understand that DDP aims “to speed up training by distributing the original dataset across workers,” but I’m still uncertain whether the following output is expected:

Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
Running basic DDP example on rank 2.
Running basic DDP example on rank 3.
[1,  2000] loss: 2.181
[1,  2000] loss: 2.187
[1,  2000] loss: 2.180
[1,  2000] loss: 2.187
[2,  2000] loss: 1.740
[2,  2000] loss: 1.737
[2,  2000] loss: 1.739
Finished Training
[2,  2000] loss: 1.737
Finished Training
Finished Training
Finished Training

While I understand that each GPU is handling its own input (and therefore output), does this indeed mean that we have 4 separate losses? If so, how does PyTorch handle each loss and form a “final loss” after each minibatch?

Also, I know that PyTorch averages the loss across each minibatch, but we’re not dealing with minibatches here (and in such a scenario each GPU would calculate its own minibatch), we’re dealing with our inputs that are evenly distributed across the GPUs (so therefore data from one minibatch may be placed on 2 different GPUs). Does DDP automatically take care of this instance where data from one minibatch may be placed on 2 different GPUs and therefore the iterations and resulting loss are communicated and calculated correctly? Or is my understanding incorrect?

Thanks @dmack for trying out DDP! Here is my understanding:

  • One way to think about data parallel training is that it increases the effective batch size. If each worker in a world of size W operates on a batch size B, then the effective batch size is W * B.
  • DDP computes the loss on each worker according to each worker’s defined loss function. This loss function may sum over the samples (i.e. reduction="sum") or average over the samples (i.e. reduction="mean").
  • In the latter case where the loss function averages over the samples, each worker r computes $L_r = \frac{1}{B}\sum_{b=1}^{B} \ell(o_{r,b}, y_{r,b})$ as the loss for its batch of size B, where $\ell$ is loss_fn and $o_{r,b}$, $y_{r,b}$ are the output and label of the b-th sample on worker r. DDP then schedules an all-reduce of the gradients of these losses: each gradient is summed across workers and divided by the world size W. Because the gradient is linear in the loss, this is the same as backpropagating the averaged loss
    $\frac{1}{W}\sum_{r=1}^{W} L_r = \frac{1}{W}\sum_{r=1}^{W}\frac{1}{B}\sum_{b=1}^{B} \ell(o_{r,b}, y_{r,b}) = \frac{1}{WB}\sum_{r=1}^{W}\sum_{b=1}^{B} \ell(o_{r,b}, y_{r,b})$. Thus, we see that DDP is training with an effective batch size of W * B when the loss function uses reduction="mean".
    Aside from the communication overhead, these W * B samples can be processed in about the same time as B samples on a single process without DDP. This is the gain from using data parallel training.
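Here is a single-process sketch of that equivalence, assuming a toy nn.Linear model and random data. No process group is involved: the W per-worker backward passes are simulated in a loop, and summing their gradients and dividing by W stands in for DDP's all-reduce:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W, B = 4, 8  # world size and per-worker batch size (toy values)
model = torch.nn.Linear(10, 3)
inputs = torch.randn(W * B, 10)
labels = torch.randint(0, 3, (W * B,))

# Single process: one batch of W * B samples with reduction="mean".
model.zero_grad()
F.cross_entropy(model(inputs), labels).backward()
full_grad = model.weight.grad.clone()

# Simulated DDP: each "worker" r backpropagates the mean loss over its own B samples.
worker_grads = []
for r in range(W):
    model.zero_grad()
    shard = slice(r * B, (r + 1) * B)
    F.cross_entropy(model(inputs[shard]), labels[shard]).backward()
    worker_grads.append(model.weight.grad.clone())

# Stand-in for the all-reduce: sum the gradients across workers and divide by W.
ddp_grad = torch.stack(worker_grads).sum(dim=0) / W

print(torch.allclose(full_grad, ddp_grad, atol=1e-6))  # True
```

Only the gradients are combined this way; the loss values themselves are not all-reduced, which is why each rank prints its own loss.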

Now, let me turn to your example script. I will assume for now that there is no early break after 2000 iterations.

  • In your script, each worker processes the entire training set. Given the earlier discussion, this is equivalent to training a single worker on a dataset that is the concatenation of W copies of the original dataset, with a batch size of W * B.
  • In normal DDP usage, we would use DistributedSampler so that each worker receives only a 1 / W fraction of the dataset. Each worker then processes 1 / W of the dataset, whereas in your version, each worker processes the full dataset.
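A minimal sketch of the usual DistributedSampler setup, assuming a dummy 100-element dataset and 2 simulated ranks. With shuffle=True, every rank must use the same seed and call set_epoch(epoch) at the start of each epoch; otherwise every epoch reuses the epoch-0 order. Also, shuffle=True must not be passed to the DataLoader when a sampler is given:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
world_size = 2


def epoch_order(rank: int, epoch: int) -> list:
    """Indices that `rank` would see during `epoch` (simulated, no process group).

    A real training loop creates the sampler once and calls set_epoch each epoch.
    """
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    sampler.set_epoch(epoch)  # reseeds the shuffle with seed + epoch
    loader = DataLoader(dataset, batch_size=10, sampler=sampler)  # no shuffle=True here
    return [i for (batch,) in loader for i in batch.tolist()]


r0_e0, r1_e0 = epoch_order(0, 0), epoch_order(1, 0)
r0_e1 = epoch_order(0, 1)
print(sorted(r0_e0 + r1_e0) == list(range(100)))  # True: shards are disjoint and cover the data
print(r0_e0 != r0_e1)                             # True: a new epoch gives a new shuffle
```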

If we call your version DMACK and the single-worker equivalent DMACK_SINGLE and suppose that they have the following configurations:

  • DMACK: batch size B; number of epochs E; dataset D
  • DMACK_SINGLE: batch size B * W; number of epochs E; dataset D replicated W times

then DMACK and DMACK_SINGLE are equivalent.

If we call the normal DDP usage DDP and the single-worker equivalent DDP_SINGLE and suppose they have the following configurations:

  • DDP: batch size B / W; number of epochs E * W; dataset D with DistributedSampler giving 1 / W fraction to each worker
  • DDP_SINGLE: batch size B; number of epochs E * W; dataset D

then DDP and DDP_SINGLE are equivalent.
Notice that in both cases, wrapping with DDP means the per-worker dataset size and batch size can be reduced by a factor of W relative to the single-worker equivalent.
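The equivalences can be checked with a little arithmetic. This sketch assumes the full CIFAR-10 training set (D = 50,000 images), the values B = 64 and W = 4 used in the comparison, and the DataLoader default drop_last=False; it counts optimizer steps per epoch and samples per step for each configuration:

```python
import math

D, B, W = 50_000, 64, 4  # CIFAR-10 train size, base batch size, world size

# name: (samples each worker sees per epoch, per-worker batch size, number of workers)
configs = {
    "DMACK": (D, B, W),
    "DMACK_SINGLE": (D * W, B * W, 1),
    "DDP": (math.ceil(D / W), B // W, W),  # DistributedSampler shard size per rank
    "DDP_SINGLE": (D, B, 1),
}

results = {}
for name, (samples, batch, workers) in configs.items():
    steps = math.ceil(samples / batch)  # the last partial batch is kept
    results[name] = (steps, batch * workers, batch)
    print(f"{name:12s} steps/epoch={steps}  effective batch={batch * workers}  "
          f"per-GPU batch={batch}")
```

All four configurations take 782 optimizer steps per epoch. The difference is that each GPU in DMACK processes 64 samples per step (with the data effectively duplicated W times), while each GPU in DDP processes 16, which, ignoring communication costs, is where the DDP speedup comes from.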


I have written a script to compare these four versions. Suppose B = 64, E = 2, and W = 4.
For DDP and DDP_SINGLE, I do not use E * W epochs (as described above) but rather E * int(W * f), where f = 0.66 is a discount factor since the speedup is sublinear due to communication overhead. Since int(4 * 0.66) = int(2.64) = 2, this yields 2 * 2 = 4 epochs. The point is to have DDP and DMACK take approximately the same amount of time to run.
On my AWS instance, I see the following results (using the full CIFAR10 training set):

  • DMACK: batch size 64; 2 epochs; training loss ~1.94; elapsed time 21 seconds
  • DMACK_SINGLE: batch size 256; 2 epochs; training loss ~1.99; elapsed time 62 seconds
  • DDP: batch size 16; 4 epochs; training loss ~1.77; elapsed time 17 seconds
  • DDP_SINGLE: batch size 64; 4 epochs; training loss ~1.76; elapsed time 36 seconds

The losses will have some variance from the random shuffling, but we see that the multi- and single-worker versions have approximately the same loss, as expected.

The takeaway is that the normal DDP usage allows us to train faster since each worker uses a smaller per-worker batch size. We see that the DDP version runs 4 epochs in less time than DMACK runs 2 epochs. (However, the speedup is never truly linear due to fixed and communication overheads.)


Finally, let me try to answer your questions in case they are still unclear after reading the above:

does this indeed mean that we have 4 separate losses? If so, how does PyTorch handle each loss and form a “final loss” after each minibatch?

Yes, each worker has its own loss. When using a loss function that averages over samples, there is no need to explicitly compute the single “final loss”. Each loss is there to compute gradients for a batch. What DDP does is equivalent to averaging each worker’s loss into a single loss, aggregating each worker’s batch into a single batch, and backpropagating the single loss through the single batch, only this is done in a distributed fashion, which yields performance benefits (faster and less memory per worker).

Also, I know that PyTorch averages the loss across each minibatch, but we’re not dealing with minibatches here (and in such a scenario each GPU would calculate its own minibatch), we’re dealing with our inputs that are evenly distributed across the GPUs (so therefore data from one minibatch may be placed on 2 different GPUs). Does DDP automatically take care of this instance where data from one minibatch may be placed on 2 different GPUs and therefore the iterations and resulting loss are communicated and calculated correctly? Or is my understanding incorrect?

The concept of minibatch is flexible. In the non-distributed single-worker context, you can think of a minibatch as the input and label tensors yielded by the dataloader on a given iteration. In the distributed data parallel context, a minibatch can refer to either the tensors yielded by a given worker’s dataloader (size B), or it can refer to the effective minibatch, which is the concatenation of each worker’s minibatch on that iteration (size W * B). When you say “dealing with our inputs that are evenly distributed across the GPUs (so therefore data from one minibatch may be placed on 2 different GPUs)”, this is referring to the effective minibatch, which is distributed across the workers. I think this confusion may be clarified by looking at the DistributedSampler. This handles the concept of effective minibatch for you since now each worker only gets 1 / W of the original dataset, and you can think again in terms of per-worker size-B minibatches.
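To illustrate the effective minibatch, this sketch simulates W = 2 ranks in one process, each with its own DataLoader over a DistributedSampler shard of a dummy 32-element dataset (shuffle=False for readability). Concatenating the i-th batch from every rank gives the effective minibatch of size W * B:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(32))
W, B = 2, 4  # world size and per-worker batch size

# One DataLoader per simulated rank; in real DDP each process builds only its own.
loaders = [
    DataLoader(
        dataset,
        batch_size=B,
        sampler=DistributedSampler(dataset, num_replicas=W, rank=r, shuffle=False),
    )
    for r in range(W)
]

# On iteration i, the effective minibatch is every rank's batch i concatenated.
effective = [
    torch.cat([batch for (batch,) in per_rank]) for per_rank in zip(*loaders)
]
print(effective[0].tolist())  # [0, 2, 4, 6, 1, 3, 5, 7]
```

Each rank only ever holds its own size-B piece; the size-W * B effective minibatch exists only implicitly, through the gradient averaging.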


Here is the comparison script for you to try out:

import argparse
import os
import sys
import time
from enum import Enum, auto

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms


class Mode(Enum):
    DMACK = auto()         # dmack's usage
    DMACK_SINGLE = auto()  # single-process version of DMACK
    DDP = auto()           # normal DDP usage
    DDP_SINGLE = auto()    # single-process version of DDP


class CNN(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def gen_data(
    rank: int,
    world_size: int,
    batch_size: int,
    mode: Mode,
):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    if mode == Mode.DMACK_SINGLE:
        trainset = torch.utils.data.ConcatDataset([
            torchvision.datasets.CIFAR10(
                root="./data", train=True, download=True, transform=transform,
            ) for _ in range(world_size)
        ])
    else:
        trainset = torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform,
        )
    if mode == Mode.DDP:
        sampler = DistributedSampler(
            trainset, num_replicas=world_size, rank=rank, shuffle=False,
            drop_last=False,
        )
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, sampler=sampler,
        )
    else:
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=True,
        )
    # A tuple (not a set) keeps the class names in label-index order.
    classes = (
        "plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship",
        "truck",
    )
    return trainset, trainloader, classes


def run_distributed(fn_per_rank, world_size, mode) -> None:
    """
    Args:
        fn_per_rank (Callable): Function to run on each rank.
        world_size (int): Number of ranks.
    """
    nprocs = 1 if mode == Mode.DMACK_SINGLE or mode == Mode.DDP_SINGLE \
        else world_size
    mp.spawn(
        fn_per_rank, args=(world_size, mode), nprocs=nprocs,
        join=True,
    )

def train(rank: int, world_size: int, mode: Mode):
    if mode != Mode.DMACK_SINGLE and mode != Mode.DDP_SINGLE:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    base_batch_size = 64
    if mode == Mode.DDP:
        BATCH_SIZE = base_batch_size // world_size
    elif mode == Mode.DMACK_SINGLE:
        BATCH_SIZE = base_batch_size * world_size
    elif mode == Mode.DMACK or mode == Mode.DDP_SINGLE:
        BATCH_SIZE = base_batch_size
    else:
        raise ValueError(f"Unsupported mode: {mode}")
    trainset, trainloader, classes = gen_data(
        rank, world_size, BATCH_SIZE, mode,
    )
    num_batches = len(trainloader)
    sys.stdout.flush()
    if rank == 0:
        print(
            f"[Rank={rank}] Training on {num_batches} batches with batch "
            f"size {BATCH_SIZE}"
        )
    device = torch.device(rank)
    model = CNN().to(device)
    if mode == Mode.DMACK or mode == Mode.DDP:
        ddp_model = DDP(model, device_ids=[rank])
    else:
        ddp_model = model
    loss_fn = torch.nn.CrossEntropyLoss()
    optim = torch.optim.SGD(ddp_model.parameters(), lr=1e-3, momentum=0.9)

    start_time = time.time()
    NUM_EPOCHS = 2
    if mode == Mode.DDP or mode == Mode.DDP_SINGLE:
        nonlinear_factor = 0.66  # non-linear speedup due to communication
        NUM_EPOCHS *= int(world_size * nonlinear_factor)
    INTERVAL = num_batches // 3  # print loss 3-4 times per epoch
    for epoch in range(1, 1 + NUM_EPOCHS):
        running_loss = 0.
        for i, (inputs, labels) in enumerate(trainloader, 1):
            optim.zero_grad()
            outputs = ddp_model(inputs.to(device))
            loss = loss_fn(outputs, labels.to(device))
            loss.backward()
            optim.step()
            running_loss += loss.item()  # plain float, detached from the autograd graph

            if i % INTERVAL == 0 and i > 0:
                avg_interval_loss = running_loss / INTERVAL
                print(
                    f"[Rank={rank} Epoch={epoch} Iter={i}] "
                    f"Loss={avg_interval_loss:.3f}"
                )
                sys.stdout.flush()
                running_loss = 0.0
    
    if mode != Mode.DMACK_SINGLE and mode != Mode.DDP_SINGLE:
        dist.destroy_process_group()
    elapsed_time = time.time() - start_time
    if rank == 0:
        print(f"[Rank={rank}] Finished training")
        print(f"[Rank={rank}] Elapsed time={elapsed_time:.3f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", "-m", type=str, default="ddp")
    args = parser.parse_args()
    mode_str = args.mode
    if mode_str == "ddp":
        mode = Mode.DDP
    elif mode_str == "ddp_single":
        mode = Mode.DDP_SINGLE
    elif mode_str == "dmack":
        mode = Mode.DMACK
    elif mode_str == "dmack_single":
        mode = Mode.DMACK_SINGLE
    else:
        raise ValueError(f"Unsupported mode: {mode_str}")
    print(f"Running in mode: {mode_str}")

    world_size = torch.cuda.device_count()
    assert mode == Mode.DMACK_SINGLE or world_size >= 2
    run_distributed(train, world_size, mode)
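To make the per-mode arithmetic in train() easier to compare, here is a small pure-Python sketch that reproduces only its batch-size, process-count and epoch logic for a given world size. The Mode enum and the run_config helper are hypothetical stand-ins (the real enum definition is not part of this excerpt), and "effective batch" here simply means the per-process batch size times the number of processes.

```python
from enum import Enum, auto


# Hypothetical stand-in for the Mode enum defined earlier in the script
# (its actual definition is not shown in this excerpt).
class Mode(Enum):
    DDP = auto()
    DDP_SINGLE = auto()
    DMACK = auto()
    DMACK_SINGLE = auto()


def run_config(mode: Mode, world_size: int, base_batch_size: int = 64,
               base_epochs: int = 2, nonlinear_factor: float = 0.66):
    """Mirror the nprocs, BATCH_SIZE and NUM_EPOCHS logic of train()."""
    single = mode in (Mode.DMACK_SINGLE, Mode.DDP_SINGLE)
    nprocs = 1 if single else world_size
    if mode == Mode.DDP:
        per_rank = base_batch_size // world_size
    elif mode == Mode.DMACK_SINGLE:
        per_rank = base_batch_size * world_size
    else:  # Mode.DMACK or Mode.DDP_SINGLE
        per_rank = base_batch_size
    epochs = base_epochs
    if mode in (Mode.DDP, Mode.DDP_SINGLE):
        epochs *= int(world_size * nonlinear_factor)
    # Samples consumed per optimizer step, summed over all processes
    effective = per_rank * nprocs
    return nprocs, per_rank, effective, epochs


for m in Mode:
    print(m.name, run_config(m, world_size=4))
```

With world_size=4 this gives DDP 16 samples per rank (64 effective, 4 epochs), DDP_SINGLE 64 (64 effective, 4 epochs), DMACK 64 per rank (256 effective, 2 epochs) and DMACK_SINGLE 256 (256 effective, 2 epochs), so each multi-process mode is paired with a single-process mode at the same effective batch size.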


Regarding the dist.barrier() hang in the checkpointing example: you need to call dist.barrier() on all ranks because it is a collective operation, meaning that every rank in the process group must call it before any of them can proceed.

You simply need to un-indent the lines from dist.barrier() up to (but not including) for epoch in range(2), so that they run on every rank instead of only on rank 0.
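dist.barrier() behaves like any barrier primitive: it only releases once every participant has called it. As a loose, PyTorch-free analogy (threads stand in for ranks, and a timeout stands in for the indefinite wait you saw), the sketch below uses Python's threading.Barrier; run_ranks is a hypothetical helper, not part of torch.distributed.

```python
import threading


def run_ranks(world_size: int, only_rank0_calls_barrier: bool,
              timeout: float = 1.0):
    """Simulate ranks with threads; report what happened at the barrier."""
    barrier = threading.Barrier(world_size)
    results = [None] * world_size

    def worker(rank: int):
        if only_rank0_calls_barrier and rank != 0:
            results[rank] = "skipped"  # this rank never reaches the barrier
            return
        try:
            barrier.wait(timeout=timeout)  # blocks until all parties arrive
            results[rank] = "passed"
        except threading.BrokenBarrierError:
            results[rank] = "timed out"  # stand-in for the hang

    threads = [threading.Thread(target=worker, args=(r,))
               for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_ranks(4, only_rank0_calls_barrier=True))   # rank 0 is stuck
print(run_ranks(4, only_rank0_calls_barrier=False))  # every rank passes
```

When the barrier sits inside a rank-0-only branch, rank 0 waits for partners that never arrive; moving the call out of that branch lets all ranks meet and continue.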

Even with that fix, your script will still error later in demo_checkpoint() because it keeps running after calling cleanup(). cleanup() does not clean up the model; it destroys the process group, which is what lets the processes communicate. Call cleanup() only when you are done with the process group itself, which is usually at the end of the train/test method. Also, do not delete the saved model state by removing CHECKPOINT_PATH, since you plan to load it later for testing.
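To see that cleanup() tears down communication rather than the model, here is a minimal CPU-only sketch, assuming cleanup() is just dist.destroy_process_group() as in the DDP tutorial. It uses a one-process gloo group initialized through a file store so it runs without GPUs or a free network port; cleanup_demo is a hypothetical name.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def cleanup_demo():
    # Single-process "gloo" group on CPU via a file store (an assumption for
    # this sketch; the tutorial's setup() uses "nccl" and MASTER_ADDR/PORT).
    store_path = os.path.join(tempfile.mkdtemp(), "pg_store")
    dist.init_process_group("gloo", init_method=f"file://{store_path}",
                            rank=0, world_size=1)
    model = nn.Linear(4, 2)
    ddp_model = DDP(model)  # CPU module, so no device_ids
    before = dist.is_initialized()

    dist.destroy_process_group()  # this is what cleanup() does
    after = dist.is_initialized()

    # The weights are untouched; only the ability to communicate is gone,
    # so any further local work uses the unwrapped module.
    out = ddp_model.module(torch.ones(1, 4))
    return before, after, tuple(out.shape)


if __name__ == "__main__":
    print(cleanup_demo())  # (True, False, (1, 2))
```

Anything that needs the group after this point (dist.barrier(), constructing a new DDP wrapper) would fail, which is why cleanup() belongs at the very end.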


Here is a fixed version of your demo_checkpoint(). I marked places I changed with [***].

def demo_checkpoint(rank, world_size):
  print(f"Running checkpoint DDP example on rank {rank}.")
  setup(rank, world_size)

  trainset, trainloader, testset, testloader, classes = generate_data()
  net = Net().to(rank)
  ddp_model = DDP(net, device_ids=[rank])
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
  if rank == 0:
    # All processes should see same parameters as they all start from same
    # random parameters and gradients are synchronized in backward passes.
    # Therefore, saving it in one process is sufficient.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

  # [***]
  # Use a barrier() to make sure that process 1 loads the model after process
  # 0 saves it.
  dist.barrier()
  # configure map_location properly
  map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
  ddp_model.load_state_dict(
      torch.load(CHECKPOINT_PATH, map_location=map_location))

  for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
      inputs, labels = data
      optimizer.zero_grad()
      labels = labels.to(rank)
      outputs = ddp_model(inputs.to(rank))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      
      if i % 2000 == 1999:
        print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
        running_loss = 0.0
        break
  
    # Saving the model after each epoch
    if rank == 0:
      # All processes should see same parameters as they all start from same
      # random parameters and gradients are synchronized in backward passes.
      # Therefore, saving it in one process is sufficient.
      torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

  # [***] Ensure that the model state is saved before any rank proceeds
  dist.barrier()

  # [***] Do not remove the `CHECKPOINT_PATH`, which contains the saved model
  # state -- if you remove it, what are you loading below?
  # Also, do not call `cleanup()` since you are still using the process group
  
  print("Finished Training")

  dataiter = iter(testloader)
  images, labels = next(dataiter)  # newer PyTorch DataLoader iterators have no .next()

  # create local model
  net = Net().to(rank)
  ddp_model = DDP(net, device_ids=[rank])

  # Load the weights saved during training into the newly created model;
  # map_location remaps tensors saved from cuda:0 onto this rank's GPU.
  map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
  ddp_model.load_state_dict(
      torch.load(CHECKPOINT_PATH, map_location=map_location))

  outputs = ddp_model(images.to(rank))
  _, predicted = torch.max(outputs, 1)
  print("Predicted: ", " ".join(f"{classes[predicted[j]]:5s}"
                                for j in range(4)))
  correct = 0
  total = 0
  with torch.no_grad():
    for data in testloader:
      images, labels = data
      labels = labels.to(rank)
      outputs = ddp_model(images.to(rank))
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
  correct_pred = {classname: 0 for classname in classes}
  total_pred = {classname: 0 for classname in classes}

  with torch.no_grad():
    for data in testloader:
      images, labels = data
      labels = labels.to(rank)
      outputs = ddp_model(images.to(rank))
      _, predictions = torch.max(outputs, 1)

      for label, prediction in zip(labels, predictions):
        if label == prediction:
          correct_pred[classes[label]] += 1
        total_pred[classes[label]] += 1

  for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

Regarding your model parallel example, the issue is that multiple ranks are competing for the same GPUs. If you want to run DDP with model parallelism, you need W * N GPUs, where W is the world size and N is the number of model shards (in your case 4, since you want to shard NetMP across dev0, dev1, dev2, and dev3). This means you would need 4 * 4 = 16 GPUs to run your example.

I have adjusted your code to work for 4 GPUs. This means that you only spawn 2 processes (i.e. use a world size of 2). Rank 0 uses GPU0 and GPU1, and rank 1 uses GPU2 and GPU3. I modified NetMP so that all convolutional layers are computed on the first GPU and all linear layers are computed on the second GPU.
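The W * N rule and the device split used here (rank 0 on GPUs 0 and 1, rank 1 on GPUs 2 and 3, i.e. dev0 = rank * 2 and dev1 = rank * 2 + 1) can be written out as a small pure-Python sketch. gpus_for_rank and max_world_size are hypothetical helpers, and the sketch assumes each rank's shards occupy consecutive GPU indices.

```python
def gpus_for_rank(rank: int, num_shards: int) -> list[int]:
    """GPU indices owned by `rank` when every rank holds `num_shards` shards."""
    return [rank * num_shards + s for s in range(num_shards)]


def max_world_size(num_gpus: int, num_shards: int) -> int:
    """Largest world size whose ranks all get disjoint GPUs (W * N <= GPUs)."""
    return num_gpus // num_shards


for shards in (2, 4):
    w = max_world_size(4, shards)
    print(f"{shards} shards on 4 GPUs -> world size {w}:",
          {r: gpus_for_rank(r, shards) for r in range(w)})
```

With 4 GPUs, 2 shards allow a world size of 2 ({0: [0, 1], 1: [2, 3]}), while 4 shards leave room for only a single rank; running 4 ranks with 4 shards each would need 16 GPUs.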

Note that you must move the input tensor to the correct device in your forward pass. The tensor must exist on the same device as the layer’s parameters for the computation to work.

class NetMP(nn.Module):
    def __init__(self, dev0, dev1):
      super(NetMP, self).__init__()
      self.dev0 = dev0
      self.dev1 = dev1
      self.conv1 = nn.Conv2d(3, 6, 5).to(dev0)
      self.pool = nn.MaxPool2d(2, 2).to(dev0)
      self.conv2 = nn.Conv2d(6, 16, 5).to(dev0)
      self.fc1 = nn.Linear(16 * 5 * 5, 120).to(dev1)
      self.fc2 = nn.Linear(120, 84).to(dev1)
      self.fc3 = nn.Linear(84, 10).to(dev1)

    def forward(self, x):
      x = x.to(self.dev0)
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = torch.flatten(x, 1) # flatten all dimensions except batch
      x = x.to(self.dev1)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
      x = self.fc3(x)
      return x

def demo_model_parallel(rank, world_size):
  # To reduce redundancy, let's implement demo_model_parallel as a
  # version of demo_basic, not demo_checkpoint
  if rank >= 2:
    return
  print(f"Running DDP with model parallel example on rank {rank}.")
  setup(rank, 2)

  dev0 = rank * 2
  dev1 = rank * 2 + 1

  trainset, trainloader, testset, testloader, classes = generate_data()
  net = NetMP(dev0, dev1)
  ddp_model = DDP(net)
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
      inputs, labels = data
      optimizer.zero_grad()
      labels = labels.to(dev1)
      outputs = ddp_model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()

      if i % 2000 == 1999:
        print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
        running_loss = 0.0
        break
        
  # [***] Move this outside the `for` loop
  cleanup()

  print("Finished Training")

This gives output like:

Running DDP with model parallel example on rank 1.
[1,  2000] loss: 2.161
[2,  2000] loss: 1.757
Running DDP with model parallel example on rank 0.
[1,  2000] loss: 2.162
[2,  2000] loss: 1.774
Finished Training


@agu - thank you for the VERY thorough answer! It’s clear you put a TON of time into this, from going in depth in your answers, to rewriting parts of my code, to running it on your end. I am at a loss for words to express my gratitude.

Since it’s late on my end, let me reread your responses tomorrow to make sure I fully understand them and to give myself time to run your code; right now my brain isn’t working. I’ll update this post tomorrow. Thank you again!

@agu - let me break down my comments:
Thanks for your proof on computing the loss over the average of the samples! Very slick and it drives the point home.

I now appreciate the difference between summing over samples (i.e. reduction="sum") and averaging over samples (i.e. reduction="mean"). What I’m now stuck on is deciding when to use which reduction method. What’s clear is that reduction="mean" works well for DDP, but I can’t think of an example of when reduction="sum" would be used. At the end of the day it might be useful to think of the reduction method as a hyperparameter, but I’m quite uncertain about this, as it feels like there should be something more decisive behind whether to use "mean" or "sum".

In the section where you say “In your script, each worker processes the entire training set…this is equivalent to training using a single worker using a dataset that is the concatenation of W copies of the original dataset and using a batch size of W * B,” I now understand why I was sitting in front of my computer trying to figure out why training seemed to take so long: it was precisely this. Your suggestion to use DistributedSampler is exactly the solution, thank you! To further underscore this point, thank you so much for the script covering DMACK, DMACK_SINGLE, DDP, and DDP_SINGLE, and for including important details like the discount factor that makes the runs comparable.
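The fix for "each worker processes the entire training set" is to give every rank a disjoint shard through DistributedSampler. The sketch below, assuming a toy 16-sample TensorDataset in place of CIFAR-10 and a world size of 4, builds one sampler per rank inside a single process purely to show the partition; shard_indices is a hypothetical helper.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset_len: int, world_size: int, epoch: int = 0):
    """Return the sample indices each rank would see in one epoch."""
    dataset = TensorDataset(torch.arange(dataset_len))
    shards = []
    for rank in range(world_size):
        # Passing num_replicas and rank explicitly lets this run without an
        # initialized process group; in a real DDP script you would usually
        # omit them so the sampler reads them from torch.distributed.
        sampler = DistributedSampler(dataset, num_replicas=world_size,
                                     rank=rank, shuffle=True, seed=0)
        sampler.set_epoch(epoch)  # reshuffle differently every epoch
        loader = DataLoader(dataset, batch_size=4, sampler=sampler)
        shards.append([int(i) for (batch,) in loader for i in batch])
    return shards


shards = shard_indices(dataset_len=16, world_size=4)
print(shards)  # four disjoint lists of 4 indices each
```

In the actual training function you would create only the sampler for the current rank, pass it as sampler= to the DataLoader (instead of shuffle=True, which cannot be combined with a sampler), and call sampler.set_epoch(epoch) at the start of every epoch.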

Thank you for reinforcing your code and responses by following up on my questions, including driving home the idea behind the minibatch and the effective minibatch.


Thank you for clearing up the use and placement of dist.barrier() and cleanup(). That being said, shouldn’t the code include cleanup() at the end of the function (and hence the train/test method) as below? I marked places I changed with [&&&]

def demo_checkpoint(rank, world_size):
  print(f"Running checkpoint DDP example on rank {rank}.")
  setup(rank, world_size)

  trainset, trainloader, testset, testloader, classes = generate_data()
  net = Net().to(rank)
  ddp_model = DDP(net, device_ids=[rank])
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
  if rank == 0:
    # All processes should see same parameters as they all start from same
    # random parameters and gradients are synchronized in backward passes.
    # Therefore, saving it in one process is sufficient.
    torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

  # [***]
  # Use a barrier() to make sure that process 1 loads the model after process
  # 0 saves it.
  dist.barrier()
  # configure map_location properly
  map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
  ddp_model.load_state_dict(
      torch.load(CHECKPOINT_PATH, map_location=map_location))

  for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
      inputs, labels = data
      optimizer.zero_grad()
      labels = labels.to(rank)
      outputs = ddp_model(inputs.to(rank))
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      
      if i % 2000 == 1999:
        print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
        running_loss = 0.0
        break
  
    # Saving the model after each epoch
    if rank == 0:
      # All processes should see same parameters as they all start from same
      # random parameters and gradients are synchronized in backward passes.
      # Therefore, saving it in one process is sufficient.
      torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

  # [***] Ensure that the model state is saved before any rank proceeds
  dist.barrier()

  # [***] Do not remove the `CHECKPOINT_PATH`, which contains the saved model
  # state -- if you remove it, what are you loading below?
  # Also, do not call `cleanup()` since you are still using the process group
  
  print("Finished Training")

  dataiter = iter(testloader)
  images, labels = next(dataiter)  # newer PyTorch DataLoader iterators have no .next()

  # create local model
  net = Net().to(rank)
  ddp_model = DDP(net, device_ids=[rank])

  # Load the weights saved during training into the newly created model;
  # map_location remaps tensors saved from cuda:0 onto this rank's GPU.
  map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
  ddp_model.load_state_dict(
      torch.load(CHECKPOINT_PATH, map_location=map_location))

  outputs = ddp_model(images.to(rank))
  _, predicted = torch.max(outputs, 1)
  print("Predicted: ", " ".join(f"{classes[predicted[j]]:5s}"
                                for j in range(4)))
  correct = 0
  total = 0
  with torch.no_grad():
    for data in testloader:
      images, labels = data
      labels = labels.to(rank)
      outputs = ddp_model(images.to(rank))
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
  correct_pred = {classname: 0 for classname in classes}
  total_pred = {classname: 0 for classname in classes}

  with torch.no_grad():
    for data in testloader:
      images, labels = data
      labels = labels.to(rank)
      outputs = ddp_model(images.to(rank))
      _, predictions = torch.max(outputs, 1)

      for label, prediction in zip(labels, predictions):
        if label == prediction:
          correct_pred[classes[label]] += 1
        total_pred[classes[label]] += 1

  for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class: {classname:5s} is {accuracy:.1f} %")

  # [&&&] Destroy the process group only after all training and testing is done
  cleanup()

Thanks for rewriting the code for 4 GPUs and for introducing me to the idea of model shards and how to work out the number of GPUs needed; things now make more sense. Just one lingering question: when you say “Note that you must move the input tensor to the correct device in your forward pass. The tensor must exist on the same device as the layer’s parameters for the computation to work,” this seems to imply that using DistributedSampler while executing demo_model_parallel may be inappropriate, because the model shards are placed on different GPUs and therefore the tensors will not exist on the same device as the layer’s parameters. Is that correct?


What I’m now stuck with deciding when to use which reduction method; what’s clear is that reduction=“mean” is effective for DDP, but I can’t think of an example of when reduction="sum" would be used…

I think the general recommendation is to use reduction="mean", since it allows your script to be agnostic to the batch size, whereas if you were to use reduction="sum", you may need to retune the hyperparameters for a new batch size. There may be case-specific reasons for using reduction="sum". PyTorch probably supports it to be flexible as a deep learning framework, but in practice, you can stick to reduction="mean".
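To make the batch-size point concrete, here is a small CPU-only sketch that feeds the same example repeated B times through a toy nn.Linear model. With reduction="mean" the weight gradient stays the same as B changes, while with reduction="sum" it grows linearly with B, which is why a learning rate tuned for one batch size would need retuning for another. grad_norm is a hypothetical helper.

```python
import torch
import torch.nn as nn


def grad_norm(reduction: str, batch_size: int) -> float:
    """Norm of the weight gradient for a batch of identical examples."""
    torch.manual_seed(0)
    model = nn.Linear(8, 3)
    # Repeat a single example so the per-sample gradient is identical and
    # only the batch size changes between calls.
    x = torch.randn(1, 8).repeat(batch_size, 1)
    y = torch.zeros(batch_size, dtype=torch.long)
    loss = nn.CrossEntropyLoss(reduction=reduction)(model(x), y)
    loss.backward()
    return model.weight.grad.norm().item()


for reduction in ("mean", "sum"):
    norms = [grad_norm(reduction, b) for b in (16, 64)]
    print(reduction, [round(n, 4) for n in norms])
```

Going from 16 to 64 samples leaves the "mean" gradient unchanged (up to floating-point rounding) and multiplies the "sum" gradient by 4.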

That being said, shouldn’t the code include cleanup() at the end of the function (and hence the train/test method) as below?

Great catch. I agree we should add a cleanup() at the end.

what this seems to imply is that using DistributedSampler while executing demo_model_parallel may be inappropriate because the model shards are placed on different GPUs, and therefore the tensors will not exist on the same device as the layer’s parameters, correct?

You can use the DistributedSampler to get inputs in CPU memory and then include code to move the input to the correct GPU in your training loop (i.e. .to(device0) where device0 is the device used by the first layer in your network). Then, in the forward pass, when you need to move from one model shard to the next, you must explicitly move the tensor to the next shard (i.e. .to(device1) where device1 is the device on which the next model shard resides).
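Here is a minimal sketch of that pattern, assuming a toy two-layer model instead of NetMP (TwoShardNet and train_steps are hypothetical names). Both devices default to "cpu" so it runs anywhere; on the 4-GPU setup above you would pass, for example, dev0 = rank * 2 and dev1 = rank * 2 + 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoShardNet(nn.Module):
    """Toy two-shard model; dev0/dev1 would be e.g. cuda:0 and cuda:1."""

    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.features = nn.Linear(8, 16).to(dev0)  # first shard
        self.head = nn.Linear(16, 3).to(dev1)      # second shard

    def forward(self, x):
        # x is expected on dev0 already (moved in the training loop)
        x = F.relu(self.features(x))
        x = x.to(self.dev1)  # explicit hop to the next shard's device
        return self.head(x)


def train_steps(dev0="cpu", dev1="cpu", steps=3):
    torch.manual_seed(0)
    model = TwoShardNet(dev0, dev1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    losses = []
    for _ in range(steps):
        # Stand-in for a CPU batch from a DataLoader with DistributedSampler
        inputs = torch.randn(4, 8)
        labels = torch.randint(0, 3, (4,))
        opt.zero_grad()
        outputs = model(inputs.to(dev0))  # input goes to the FIRST shard
        loss = F.cross_entropy(outputs, labels.to(dev1))  # where outputs live
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


print(train_steps())
```

Under DDP, a module that spans several devices like this is wrapped as DDP(model) without device_ids, as demo_model_parallel does.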


This is a good general rule. Got it!

Got it, and the stuff in bold sounds super tedious and error-prone. Good to know that explicitly moving the tensor is the correct step, though.

Thanks for all of your help, @agu! I’m currently marking your first response as the answer though all of your responses have been incredibly helpful. Thank you!