Verifying DDP Model Parameter Synchronization

To verify my understanding of DDP’s model parameter synchronization, I started with a [tutorial snippet][1] and instrumented the code to save model snapshots before and after each call to backward().

- Ubuntu 20.04
- PyTorch torch-1.7.1-py3.8
- torch.cuda.nccl.version(): 2708
- 2xNvidia GTX Titan
- Single machine, 2 processes, one for each of the GPUs

What I expected was:

  1. Within one GPU, the model parameters after back prop would differ
     from their values before the back prop.

     Observed: they were equal.

  2. Across the two GPUs, model parameters would be the same after
     back prop, because DDP synchronizes them.

     Observed: they are indeed equal, but given 1., can I believe any sync is occurring?

In a separate experiment I used Wireshark to watch for packet-level activity between the PyTorch processes, but saw none.
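
One extra sanity check, sketched below on the assumption that the bundled NCCL honors the standard NCCL_DEBUG variable: have NCCL log its setup and which transport it selects. On a single machine it may use P2P or shared memory rather than TCP, which might explain the quiet Wireshark trace.

# Sketch only, not part of the scripts below (assumes NCCL_DEBUG is honored
# by the bundled NCCL). Set in each process before dist.init_process_group();
# NCCL then logs its initialization and the transport (P2P, SHM, or NET) it picks.
import os
os.environ['NCCL_DEBUG'] = 'INFO'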

What is the hitch in my thinking, or in the test implementation?

The code below is run via

python src/birdsong/minimal_ddp_launcher.py

Its output is:

python src/birdsong/minimal_ddp_launcher.py
Starting /home/paepcke/EclipseWorkspaces/birds/src/birdsong/minimal_ddp.py[0] of 2
Starting /home/paepcke/EclipseWorkspaces/birds/src/birdsong/minimal_ddp.py[1] of 2
Running basic DDP example on rank 1.
Running basic DDP example on rank 0.
Proc1: saving arrays of before and after models.
Proc0: saving arrays of before and after models.
Rank 1 is done.
Rank 0 is done.
Suspicious: corresponding pre-backward model parms match exactly across all processes
Good: corresponding post-backward model parms match exactly across processes
Suspicious: back prop has no impact on model parms

minimal_ddp.py:

#!/usr/bin/env python

import os
import sys
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP

class MinimalDDP:
    '''Test whether DDP really does something'''
    
    epochs  = 2
    samples = 3

    #------------------------------------
    # setup
    #-------------------

    def setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    #------------------------------------
    # demo_basic
    #-------------------

    def demo_basic(self, rank, world_size, model_save_dir='/tmp'):
        '''The action: train model; save intermediate states'''
            
        print(f"Running basic DDP example on rank {rank}.")
        self.setup(rank, world_size)
    
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])
    
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # For saving model copies
        # before and after back prop
        # for each loop iteration:
        
        before = []
        after  = []
        
        for _epoch in range(self.epochs):
            for _i in range(self.samples):
                
                optimizer.zero_grad()
                outputs = ddp_model(torch.randn(20, 10).to(rank))
                labels = torch.randn(20, 5).to(rank)
                
                # Copy and save model copies before and
                # after back prop:
                before.append(copy.deepcopy(ddp_model))
                loss_fn(outputs, labels).backward()
                after.append(copy.deepcopy(ddp_model))

                optimizer.step()

                # Clean GPU memory:
                outputs.cpu()
                labels.cpu()

        dist.barrier()

        # Save the state_dicts of all before-prop
        # and after-prop model copies; each in its
        # own file:
        self.save_model_arrs(rank, before, after, model_save_dir)
        
        self.cleanup()
        
        if rank == 0:
            # Using the saved files, 
            # verify that model parameters
            # change, and are synchronized
            # as expected:
            
            self.report_model_diffs()

    #------------------------------------
    # save_model_arrs 
    #-------------------
    
    def save_model_arrs(self, rank, before_arr, after_arr, model_save_dir):
        '''Save the state_dicts of the models in the arrays to files'''
        
        print(f"Proc{rank}: saving arrays of before and after models.")
        
        for i, (model_before, model_after) in enumerate(zip(before_arr, after_arr)):
            model_before.cpu()
            model_after.cpu()
            torch.save(model_before.state_dict(),
                       os.path.join(model_save_dir, f"before_models_r{rank}_{i}.pth"))
            torch.save(model_after.state_dict(),
                       os.path.join(model_save_dir, f"after_models_r{rank}_{i}.pth"))

    #------------------------------------
    # report_model_diffs 
    #-------------------

    def report_model_diffs(self, model_save_dir='/tmp'):
        '''Check that model parms changed or 
            were synched as expected '''
        
        model_arrs_len = self.epochs * self.samples
        
        # Among GPUs, model parms should differ
        # before backprop... 
        befores_differ_among_GPUs   = True    # that's the hope
        # ... but be synched by DDP after
        afters_differ_among_GPUs    = False   # that's the hope
        
        # Within a single GPU, the model should be
        # changed by the backprop:
        befores_differ_from_afters  = True    # that's the hope
        
        for i in range(model_arrs_len):
            before_path_r0 = os.path.join(model_save_dir, f"before_models_r0_{i}.pth")
            before_path_r1 = os.path.join(model_save_dir, f"before_models_r1_{i}.pth")
            
            after_path_r0 = os.path.join(model_save_dir, f"after_models_r0_{i}.pth")
            after_path_r1 = os.path.join(model_save_dir, f"after_models_r1_{i}.pth")
            
            before_state0 = torch.load(before_path_r0)
            before_state1 = torch.load(before_path_r1)
            
            after_state0 = torch.load(after_path_r0)
            after_state1 = torch.load(after_path_r1)
            
            # The between-GPUs test:
            for (param_tns0, param_tns1) in zip(before_state0, before_state1):
                if before_state0[param_tns0].eq(before_state1[param_tns1]).all():
                    # Dang!
                    befores_differ_among_GPUs = False
            
            for (param_tns0, param_tns1) in zip(after_state0, after_state1):
                if after_state0[param_tns0].ne(after_state1[param_tns1]).any():
                    # Dang!
                    afters_differ_among_GPUs = True
                    
            # The within-GPUs test:
            for (param_tns_pre, param_tns_post) in zip(before_state0, after_state0):
                if before_state0[param_tns_pre].eq(after_state0[param_tns_post]).all():
                    # Dang!
                    befores_differ_from_afters = False
            
        if befores_differ_among_GPUs:
            print("Good: corresponding pre-backward model parms in processes differ")
        else:
            print("Suspicious: corresponding pre-backward model parms match exactly across all processes")
            
        if afters_differ_among_GPUs:
            print("Bad: backward does not seem to broadcast parms")
        else:
            print("Good: corresponding post-backward model parms match exactly across processes")            

        # Within one GPU, model parms before and 
        # after back prop should be different.
        if befores_differ_from_afters:
            print("Good: back prop does change model parms")
        else:
            print("Suspicious: back prop has no impact on model parms") 


    #------------------------------------
    # cleanup 
    #-------------------

    def cleanup(self):
        dist.destroy_process_group()
        print(f"Rank {rank} is done.")
        
# ------------------------ Toy Model ----------

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

# ------------------------ Main ------------
if __name__ == '__main__':

    rank           = int(sys.argv[1])
    world_size     = 2
    model_save_dir = '/tmp'
    min_ddp = MinimalDDP()
    min_ddp.demo_basic(rank, world_size, model_save_dir)

minimal_ddp_launcher.py:

import subprocess
import os

class MinimalDDPLauncher:
   
    def run_demo(self, demo_script, world_size):
        procs = []
        for rank in range(world_size):
            print(f"Starting {demo_script}[{rank}] of {world_size}")
            procs.append(subprocess.Popen([demo_script, str(rank), str(world_size)]))
            
        for proc in procs:
            proc.wait()

# ------------------------ Main ------------
if __name__ == '__main__':

    curr_dir = os.path.dirname(__file__)
    script_path = os.path.join(curr_dir, 'minimal_ddp.py')
    
    launcher = MinimalDDPLauncher()
    launcher.run_demo(script_path, 2)

[1] Getting Started with Distributed Data Parallel — PyTorch Tutorials 1.7.1 documentation


Hi, it seems like this is the essential portion of your code: it saves the params before/after backward. However, simply calling backward() is not enough to modify the model parameters; you also need to call optimizer.step(), which actually applies the averaged gradients to the parameters.
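
For instance, here is a minimal single-process sketch (no DDP; a plain nn.Linear stands in for your ToyModel) of why the snapshot placement matters:

# Minimal sketch, single process, no DDP: backward() only populates .grad;
# optimizer.step() is what actually changes the parameters.
import copy
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

before = copy.deepcopy(model.state_dict())
loss_fn(model(torch.randn(20, 10)), torch.randn(20, 5)).backward()
after_backward = copy.deepcopy(model.state_dict())
optimizer.step()
after_step = copy.deepcopy(model.state_dict())

for key in before:
    print(key,
          "| unchanged by backward():", before[key].eq(after_backward[key]).all().item(),
          "| changed by step():", before[key].ne(after_step[key]).any().item())

So in your script the after snapshot has to be taken after optimizer.step() rather than right after backward().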


Thank you Rohan! Moving the test to after the optimizer step confirmed that the parameters before and after are indeed different, as expected, both with and without DDP.

However, with the code below I still do not see evidence of DDP’s synchronization across the two GPUs (single machine, 2 processes). I marked the portion of interest; the rest just makes the code runnable.
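
As an aside, a lighter-weight way to compare parameters across ranks (just a sketch of an alternative, not what the script below does; it reuses ddp_model, rank, world_size, and dist from that script) would be to all_gather a flattened copy of the parameters and compare in place, instead of round-tripping through files in /tmp:

# Sketch of an alternative cross-rank comparison; assumes ddp_model, rank,
# world_size, and dist are defined as in the script below. NCCL requires
# the gathered tensors to live on the GPU.
flat = torch.cat([p.detach().flatten() for p in ddp_model.parameters()]).to(rank)
gathered = [torch.zeros_like(flat) for _ in range(world_size)]
dist.all_gather(gathered, flat)
if rank == 0:
    print("Params equal across ranks:",
          all(torch.equal(gathered[0], g) for g in gathered[1:]))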

The output is:

python minimal_ddp_launcher.py minimal_across_two_gpus_ddp.py
Starting minimal_across_two_gpus_ddp.py[0] of 2
Starting minimal_across_two_gpus_ddp.py[1] of 2
Running basic DDP on two GPUs same machine: rank 0.
Running basic DDP on two GPUs same machine: rank 1.
Epoch0 batch0: Before states across gpus are equal
Epoch0 batch0: After states across gpus are equal
Epoch0 batch1: Before states across gpus are equal
Epoch0 batch1: After states across gpus are different
Epoch0 batch2: Before states across gpus are different
Epoch0 batch2: After states across gpus are different
Epoch1 batch0: Before states across gpus are different
Epoch1 batch0: After states across gpus are different
Epoch1 batch1: Before states across gpus are different
Epoch1 batch1: After states across gpus are different
Epoch1 batch2: Before states across gpus are different
Epoch1 batch2: After states across gpus are different
Rank 0 is done.
Rank 1 is done.

In minimal_across_two_gpus_ddp.py:

#!/usr/bin/env python

import os
import sys
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch import randn

from torch.nn.parallel import DistributedDataParallel as DDP

class MinimalDDP:
    '''Test whether DDP really does something'''
    
    epochs  = 2
    batches = 3

    #------------------------------------
    # setup
    #-------------------

    def setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    #------------------------------------
    # demo_basic
    #-------------------

    def demo_basic(self, rank, world_size):
            
        print(f"Running basic DDP on two GPUs same machine: rank {rank}.")
        self.setup(rank, world_size)
    
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])
    
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        dist.barrier()
        
        for epoch_num in range(self.epochs):
            for batch_num in range(self.batches):
                
                optimizer.zero_grad()
                outputs = ddp_model(randn(20, 10).to(rank))
                labels = randn(20, 5).to(rank)
                
                #********* Begin Portion of Interest ******
                before_model = ddp_model.cpu()
                before_state = copy.deepcopy(before_model.state_dict())
                if rank == 1:
                    torch.save(before_state, f"/tmp/before_rank1.pth")
                ddp_model.to(rank)
                
                loss_fn(outputs, labels).backward()
                optimizer.step()

                after_model = ddp_model.cpu()
                after_state = after_model.state_dict()
                if rank == 1:
                    torch.save(after_state, f"/tmp/after_rank1.pth")
                ddp_model.to(rank)
                                
                dist.barrier()
                
                # Read the other's before/after states:
                if rank == 0:
                    other_before_state = torch.load(f"/tmp/before_rank1.pth")
                    other_after_state  = torch.load(f"/tmp/after_rank1.pth")                
                
                    # Before states should be different:
                    states_equal = True
                    for before_parm, other_before_parm in zip(other_before_state.values(),
                                                              before_state.values()):
                        if before_parm.ne(other_before_parm).any():
                            states_equal = False
    
                    print(f"Epoch{epoch_num} batch{batch_num}: Before states across gpus are {('equal' if states_equal else 'different')}")


                    # After states should be the same:
                    states_equal = True
                    for after_parm_other, after_parm in zip(other_after_state.values(),
                                                       after_state.values()):
                        if after_parm_other.ne(after_parm).any():
                            states_equal = False
    
                    print(f"Epoch{epoch_num} batch{batch_num}: After states across gpus are {('equal' if states_equal else 'different')}")

                #********* End Portion of Interest ******
                # Clean GPU memory:
                outputs.cpu()
                labels.cpu()

        dist.barrier()

        self.cleanup()

    #------------------------------------
    # cleanup 
    #-------------------

    def cleanup(self):
        dist.destroy_process_group()
        print(f"Rank {rank} is done.")
        
# ------------------------ Toy Model ----------

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

# ------------------------ Main ------------
if __name__ == '__main__':

    # Started via minimal_ddp_launcher.py,
    # which sets the rank:
    
    rank           = int(sys.argv[1])
    world_size     = 2
    min_ddp = MinimalDDP().demo_basic(rank, world_size)

And in minimal_ddp_launcher.py:

import subprocess
import os, sys

class MinimalDDPLauncher:
   
    def run_demo(self, demo_script, world_size):
        procs = []
        for rank in range(world_size):
            print(f"Starting {demo_script}[{rank}] of {world_size}")
            procs.append(subprocess.Popen([demo_script, str(rank), str(world_size)]))
            
        for proc in procs:
            proc.wait()

# ------------------------ Main ------------
if __name__ == '__main__':

    if len(sys.argv) < 2:
        print("Usage: {minimal_within_two_gpus_ddp.py | minimal_across_two_gpus_ddp.py}")
        sys.exit(1) 
    curr_dir = os.path.dirname(__file__)
    script_path = os.path.join(curr_dir, sys.argv[1])
    
    launcher = MinimalDDPLauncher()
    launcher.run_demo(script_path, 2)

Further support for the impression that DDP does not average the gradients automatically: if I add

for param in ddp_model.parameters():
    dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
    param.grad.data /= world_size 

right after the backward() call, the output is what I expect (apart from the warnings about the deprecated reduce_op in PyTorch):

Running basic DDP on two GPUs same machine: rank 1.
Running basic DDP on two GPUs same machine: rank 0.
/home/paepcke/anaconda3/envs/birds/lib/python3.9/site-packages/torch-1.7.1-py3.9-linux-x86_64.egg/torch/distributed/distributed_c10d.py:142: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn("torch.distributed.reduce_op is deprecated, please use "
/home/paepcke/anaconda3/envs/birds/lib/python3.9/site-packages/torch-1.7.1-py3.9-linux-x86_64.egg/torch/distributed/distributed_c10d.py:142: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
  warnings.warn("torch.distributed.reduce_op is deprecated, please use "
Epoch0 batch0: Before states across gpus are equal
Epoch0 batch0: After states across gpus are equal
Epoch0 batch1: Before states across gpus are equal
Epoch0 batch1: After states across gpus are equal
Epoch0 batch2: Before states across gpus are equal
Epoch0 batch2: After states across gpus are equal
Epoch1 batch0: Before states across gpus are equal
Epoch1 batch0: After states across gpus are equal
Epoch1 batch1: Before states across gpus are equal
Epoch1 batch1: After states across gpus are equal
Epoch1 batch2: Before states across gpus are equal
Epoch1 batch2: After states across gpus are equal
Rank 0 is done.
Rank 1 is done.

If I set async_op=True, the states come out different, just as they do without the explicit gradient averaging. This seems to indicate that the two processes really do need to run in lockstep.
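
For reference, the same explicit averaging spelled with the non-deprecated enum (a sketch; ddp_model, world_size, and dist as in the full script below) should silence the warning:

# Same explicit gradient averaging, but using the non-deprecated ReduceOp enum
# (sketch; ddp_model, world_size, and dist as in the full script below):
for param in ddp_model.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, async_op=False)
    param.grad.data /= world_size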

Simply for cut/paste convenience, I attach the runnable code below. The only change from the code in the previous reply is the explicit gradient-averaging loop.

run like this: python minimal_ddp_launcher.py minimal_across_two_gpus_ddp.py

Two files:
=============== CUT: File minimal_across_two_gpus_ddp.py:

#!/usr/bin/env python

import os
import sys
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch import randn

from torch.nn.parallel import DistributedDataParallel as DDP

class MinimalDDP:
    '''Test whether DDP really does something'''
    
    epochs  = 2
    batches = 3

    #------------------------------------
    # setup
    #-------------------

    def setup(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    #------------------------------------
    # demo_basic
    #-------------------

    def demo_basic(self, rank, world_size):
            
        print(f"Running basic DDP on two GPUs same machine: rank {rank}.")
        self.setup(rank, world_size)
    
        # create model and move it to GPU with id rank
        model = ToyModel().to(rank)
        ddp_model = DDP(model, device_ids=[rank])
    
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        dist.barrier()
        
        for epoch_num in range(self.epochs):
            for batch_num in range(self.batches):
                
                optimizer.zero_grad()
                outputs = ddp_model(randn(20, 10).to(rank))
                labels = randn(20, 5).to(rank)
                
                #********* Begin Portion of Interest ******
                before_model = ddp_model.cpu()
                before_state = copy.deepcopy(before_model.state_dict())
                if rank == 1:
                    torch.save(before_state, f"/tmp/before_rank1.pth")
                ddp_model.to(rank)
                
                loss_fn(outputs, labels).backward()
                #******
                for param in ddp_model.parameters():
                    dist.all_reduce(param.grad.data, 
                                    op=dist.reduce_op.SUM,
                                    async_op=False)
                    param.grad.data /= world_size 
                #******
                optimizer.step()

                after_model = ddp_model.cpu()
                after_state = after_model.state_dict()
                if rank == 1:
                    torch.save(after_state, f"/tmp/after_rank1.pth")
                ddp_model.to(rank)
                                
                dist.barrier()
                
                # Read the other's before/after states:
                if rank == 0:
                    other_before_state = torch.load(f"/tmp/before_rank1.pth")
                    other_after_state  = torch.load(f"/tmp/after_rank1.pth")                
                
                    # Before states should be different:
                    states_equal = True
                    for before_parm, other_before_parm in zip(other_before_state.values(),
                                                              before_state.values()):
                        if before_parm.ne(other_before_parm).any():
                            states_equal = False
    
                    print(f"Epoch{epoch_num} batch{batch_num}: Before states across gpus are {('equal' if states_equal else 'different')}")


                    # After states should be the same:
                    states_equal = True
                    for after_parm_other, after_parm in zip(other_after_state.values(),
                                                       after_state.values()):
                        if after_parm_other.ne(after_parm).any():
                            states_equal = False
    
                    print(f"Epoch{epoch_num} batch{batch_num}: After states across gpus are {('equal' if states_equal else 'different')}")

                #********* End Portion of Interest ******
                # Clean GPU memory:
                outputs.cpu()
                labels.cpu()

        dist.barrier()

        self.cleanup()

    #------------------------------------
    # cleanup 
    #-------------------

    def cleanup(self):
        dist.destroy_process_group()
        print(f"Rank {rank} is done.")
        
# ------------------------ Toy Model ----------

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

# ------------------------ Main ------------
if __name__ == '__main__':

    # Started via minimal_ddp_launcher.py,
    # which sets the rank:
    
    rank           = int(sys.argv[1])
    world_size     = 2
    min_ddp = MinimalDDP().demo_basic(rank, world_size)

=============== CUT: File minimal_ddp_launcher.py:

import subprocess
import os, sys

class MinimalDDPLauncher:
   
    def run_demo(self, demo_script, world_size):
        procs = []
        for rank in range(world_size):
            print(f"Starting {demo_script}[{rank}] of {world_size}")
            procs.append(subprocess.Popen([demo_script, str(rank), str(world_size)]))
            
        for proc in procs:
            proc.wait()

# ------------------------ Main ------------
if __name__ == '__main__':

    if len(sys.argv) < 2:
        print("Usage: {minimal_within_two_gpus_ddp.py | minimal_across_two_gpus_ddp.py}")
        sys.exit(1) 
    curr_dir = os.path.dirname(__file__)
    script_path = os.path.join(curr_dir, sys.argv[1])
    
    launcher = MinimalDDPLauncher()
    launcher.run_demo(script_path, 2)