Multiple-exit model issue with DistributedDataParallel (DDP)

Hello PyTorch Community,

Hope everyone is well!

Recently we had some extra GPUs added to my lab's machine, since we need to profile certain models for a research project that very quickly overwhelm a single GPU, so data parallelisation seemed like the obvious solution.

Before I explain my issue, please note that the model runs fine on a single GPU (although the problem might still be in the model and only surface in the more “delicate” distributed case).

To be honest, I have no idea what is really going on. This is my first time working with DDP; I have worked through a dozen or so different DDP-related errors and finally ended up with this one, which looks like a model-related problem that, weirdly, only occurs in the DDP case.

import os
from sys import argv

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from google_net import my_net  # model definition (posted further down)


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def main(rank,s,bs,epochs,world_size):
    setup(rank, world_size)

    dataloader = prepare(s,rank, world_size, bs)

    ddp_model = DDP(my_net(s).to(rank), device_ids=[rank], output_device=rank, find_unused_parameters=False)

    size = len(dataloader.dataset)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    ddp_model.train()

    for epoch in range(epochs):
        dataloader.sampler.set_epoch(epoch)

        for step, (X, y) in enumerate(dataloader):
            X, y = X.to(rank), y.to(rank)

            preds = []
            for e in ddp_model.module.exits:
                preds.append(ddp_model(X, e))
        
            losses = []
            for y_hat in preds:
                losses.append(criterion(y_hat, y))
            
            loss = losses[0] * 0.15 + losses[1] * 0.15 + losses[2] * 0.6

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            if rank == 0 and step % 50 == 0:
                current = step * len(X)
                for indx,l in enumerate(losses):
                    tloss = l.item()
                    print(f"loss of exit {indx}: {tloss:>7f}  [{current:>5d}/{size:>5d}]")
                print('\n')

    cleanup()

def data_parallel():
    s = int(argv[1])
    bs = int(argv[2])
    epochs = int(argv[3])

    world_size = 3    
    mp.spawn(
        main,
        args=(s,bs,epochs,world_size,),
        nprocs=world_size,
        join=True
    )

The above is the DDP code I wrote for training the model. The training loop is exactly the same as in the single-GPU case (posted below for reference, along with the model).

Note: As you can guess from the above code, the model implements an early-exit mechanism. The current model, Inception V1 (GoogLeNet), already has these early exits to mitigate gradient issues, but here we are interested in their performance, and the rest of our models have been augmented with early exits in a similar way, so it is straightforward to train the model like this (a toy sketch of the idea follows below). Please also note that the current implementation (see the model code further down) is not very optimal, but that is beyond the scope of this post :grin:.
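In case the early-exit idea is unfamiliar, here is a toy illustration (deliberately tiny stand-in modules, nothing to do with my actual architecture below): an early exit is just a small classifier head attached to an intermediate feature map of the shared trunk, and at training time the per-exit losses are combined with fixed weights.

import torch.nn as nn

# Toy early-exit network (illustrative only, not my model)
class TinyEarlyExitNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # shared trunk, split into two stages
        self.trunk1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.trunk2 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())
        # early classifier attached to the intermediate feature map
        self.exit0_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, num_classes))
        # final classifier at the end of the trunk
        self.final_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, num_classes))

    def forward(self, x):
        f1 = self.trunk1(x)
        out0 = self.exit0_head(f1)                # early exit
        out1 = self.final_head(self.trunk2(f1))   # final exit
        return out0, out1

# training-time loss is a weighted sum over the exits, e.g.:
#   loss = 0.3 * criterion(out0, y) + 0.7 * criterion(out1, y)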

The error I am getting is the following (stderr, with torch.autograd.set_detect_anomaly(True) enabled):

[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 129, in _main
    return self._bootstrap(parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 42, in main
    preds.append(ddp_model(X, e))
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/george/edge ai/google_net.py", line 326, in forward
    x = l(x)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/george/edge ai/google_net.py", line 92, in forward
    x = self.bn(x)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 2282, in batch_norm
    return torch.batch_norm(
 (function _print_stack)
Traceback (most recent call last):
  File "/home/george/edge ai/train_googlenet.py", line 173, in <module>
    data_parallel()
  File "/home/george/edge ai/train_googlenet.py", line 70, in data_parallel
    mp.spawn(
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 50, in main
    loss.backward()
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Although I understand what the issue is, I have hardly any clue what is causing it, since it seems like something that should also affect the single-GPU model, yet it does not. Below you can find the model architecture, the distributed dataloader, and the single-GPU training loop. Sorry for the wall of code; I just want to make sure everything one might need to help me is present.

The line with the BatchNorm call is highlighted in the model code with a comment.

Thanks a lot to everyone helping users on the forum. You are all heroes.

All the best,
George

EDIT:

I failed to mention this, since I have had many issues over the course of this parallelisation journey of mine, but I have already tried find_unused_parameters=True and got the following error:

[W python_anomaly_mode.cpp:104] Warning: Error detected in torch::autograd::AccumulateGrad. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 129, in _main
    return self._bootstrap(parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 25, in main
    ddp_model = DDP(my_net(s).to(rank), device_ids=[rank], output_device=rank, find_unused_parameters=True)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 587, in __init__
    self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 632, in _ddp_init_helper
    self.reducer = dist.Reducer(
 (function _print_stack)
Traceback (most recent call last):
  File "/home/george/edge ai/train_googlenet.py", line 173, in <module>
    data_parallel()
  File "/home/george/edge ai/train_googlenet.py", line 70, in data_parallel
    mp.spawn(
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 50, in main
    loss.backward()
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 186 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

When using a static graph (_set_static_graph()), I got the following:

RuntimeError: Your training graph has changed in this iteration, e.g., one parameter is unused in first iteration, but then got used in the second iteration. this is not compatible with static_graph set to True.

So I assume this also leads to a dead end, unless I am missing something.

A possible solution might be hidden in this:

Please make sure model parameters are not shared across multiple concurrent forward-backward passes

Although I have no clue which parameters are shared, or how, since each forward/backward pass is separate from the others, at least in the code I implemented (I am not sure about the PyTorch backend, but I assume some parameter sharing must be going on for the synchronisation of gradients).
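To try to make the "shared parameters" point concrete for my own model, this is a quick check I can run against the layer lists (a sketch only, assuming my_net is importable from google_net; the model code is posted further down):

import torch.nn as nn

from google_net import my_net  # model definition posted below

model = my_net(224)

def module_ids(layer_list):
    # each exit "path" is a plain Python list of modules
    return {id(m) for m in layer_list if isinstance(m, nn.Module)}

shared = module_ids(model.exit0) & module_ids(model.exit1) & module_ids(model.core)
print(f"{len(shared)} trunk modules (cc11 ... cinc4a) belong to all three exit paths")

So every trunk parameter receives gradients from all three losses, and each of the three ddp_model(X, e) calls is a separate pass through the DDP wrapper before the single backward. I suspect that is what the reducer's bookkeeping objects to, but I may well be wrong.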


Edit:

With TORCH_DISTRIBUTED_DEBUG enabled, I get the following, more specific message:

Parameter at index 185 with name e1_fc2.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.

import os
from sys import argv

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms


def prepare(size, rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    t = transforms.Compose([transforms.Resize((size+10, size+10)),
                        transforms.RandomCrop((size, size)),
                        transforms.RandomHorizontalFlip(p=0.5),
                        transforms.ToTensor()])
                        
    dataset = datasets.CIFAR10(
        root='data',
        train=True,
        download=True,
        transform=t
    )

    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    
    dataloader = DataLoader(dataset, batch_size=batch_size,
                         pin_memory=pin_memory, num_workers=num_workers,
                        drop_last=False, shuffle=False, sampler=sampler)
    
    return dataloader

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)

##### THIS CAUSES THE ERROR        
        x = self.bn(x)

        x = F.relu(x)
        return x
        
class my_inception(nn.Module):
    def __init__(self, in_size, g1out, g2hid, g2out, g3hid, g3out, g4out):
        super(my_inception, self).__init__()
        
        self.g1c1x1 = BasicConv2d(in_size, g1out, kernel_size=1, stride=1, padding=0)
        
        self.g2c1x1 = BasicConv2d(in_size, g2hid, kernel_size=1, stride=1, padding=0)
        self.g2c3x3 = BasicConv2d(g2hid, g2out, kernel_size=3, stride=1, padding=1)
        
        self.g3c1x1 = BasicConv2d(in_size, g3hid, kernel_size=1, stride=1, padding=0)
        self.g3c5x5 = BasicConv2d(g3hid, g3out, kernel_size=5, stride=1, padding=2)
        
        self.g4mp = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.g4c1x1 = BasicConv2d(in_size, g4out, kernel_size=1, stride=1, padding=0)
        
    def forward(self, x):
        g1_out = self.g1c1x1(x)

        g2_out = self.g2c1x1(x)
        g2_out = self.g2c3x3(g2_out)
        
        g3_out = self.g3c1x1(x)
        g3_out = self.g3c5x5(g3_out)

        g4_out = self.g4mp(x)
        g4_out = self.g4c1x1(g4_out)

        return torch.cat([g1_out, g2_out, g3_out, g4_out], dim=1) 
                
class my_net(nn.Module):
    def __init__(self, img_size, num_out=10):
        super(my_net, self).__init__()
        
        ### CORE ###
        self.cc11 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        
        self.cmp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
        self.clrn1 = nn.LocalResponseNorm(2)
        
        self.cc21 = BasicConv2d(64, 192, kernel_size=1, stride=1, padding=1)
        
        self.cc22 = BasicConv2d(192, 192, kernel_size=3, stride=1, padding=0)
        
        self.cmp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.clrn2 = nn.LocalResponseNorm(2)
        
        self.cinc3a = my_inception(192, 64, 96, 128, 16, 32, 32)
        
        self.cinc3b = my_inception(256, 128, 128, 192, 32, 96, 64)
        
        self.cmp3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.cinc4a = my_inception(480, 192, 96, 208, 16, 48, 64)
        
        self.cinc4b = my_inception(512, 160, 112, 224, 24, 64, 64)
        
        self.cinc4c = my_inception(512, 128, 128, 256, 24, 64, 64)
        
        self.cinc4d = my_inception(512, 112, 144, 288, 32, 64, 64)
        
        self.cinc4e = my_inception(528, 256, 160, 320, 32, 128, 128)
        
        self.cmp4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.cinc5a = my_inception(832, 256, 160, 320, 32, 128, 128)
        
        self.cinc5b = my_inception(832, 384, 192, 384, 48, 128, 128)
        
        if img_size == 224:
            self.cap = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
            self.cfl = nn.Flatten()
            self.cdo1 = nn.Dropout(0.4)
            self.cfc1 = nn.Linear(1024, 10)
        elif img_size == 512:
            self.cap = nn.AvgPool2d(kernel_size=14, stride=1, padding=0)
            self.cfl = nn.Flatten()
            self.cdo1 = nn.Dropout(0.4)
            self.cfc1 = nn.Linear(9216, 10)
        elif img_size == 1024:
            self.cap = nn.AvgPool2d(kernel_size=28, stride=1, padding=0)
            self.cfl = nn.Flatten()
            self.cdo1 = nn.Dropout(0.4)
            self.cfc1 = nn.Linear(25600, 10)
        
        self.core = [
            self.cc11,
            self.cmp1 ,
            self.clrn1 ,
            self.cc21,
            self.cc22 ,
            self.cmp2 ,
            self.clrn2,
            self.cinc3a ,
            self.cinc3b,
            
            self.cmp3,
            
            self.cinc4a,
            self.cinc4b,
            self.cinc4c,
            self.cinc4d,
            self.cinc4e,
            
            self.cmp4,
            
            self.cinc5a,
            self.cinc5b,
            
            self.cap,
            
            self.cfl,
            self.cdo1,
            self.cfc1, 
        ]
        # exit 0
        if img_size == 224:
            self.e0_ap = nn.AvgPool2d(kernel_size=5, stride=3, padding=0)
            self.e0_conv = BasicConv2d(512, 128, kernel_size=1, stride=1, padding=0)
            self.e0_flat = nn.Flatten()
            self.e0_fc1 = nn.Linear(2048, 1024)
        elif img_size == 512:
            self.e0_ap = nn.AvgPool2d(kernel_size=10, stride=3, padding=0)
            self.e0_conv = BasicConv2d(512, 128, kernel_size=1, stride=1, padding=0)
            self.e0_flat = nn.Flatten()
            self.e0_fc1 = nn.Linear(8192, 1024)
        elif img_size == 1024:
            self.e0_ap = nn.AvgPool2d(kernel_size=20, stride=3, padding=0)
            self.e0_conv = BasicConv2d(512, 128, kernel_size=1, stride=1, padding=0)
            self.e0_flat = nn.Flatten()
            self.e0_fc1 = nn.Linear(28800, 1024)
        
        
        self.e0_r = nn.ReLU(True)
        self.e0_do = nn.Dropout(0.7)
        self.e0_fc2 = nn.Linear(1024, num_out)
        
        self.exit0 = [
            self.cc11,
            self.cmp1 ,
            self.clrn1 ,
            self.cc21,
            self.cc22 ,
            self.cmp2 ,
            self.clrn2,
            self.cinc3a ,
            self.cinc3b,
            
            self.cmp3,
            
            self.cinc4a,
            
            self.e0_ap, 
            self.e0_conv,

            self.e0_flat,

            self.e0_fc1, 
            self.e0_r,
            self.e0_do,
            self.e0_fc2,
        ]
        # exit 1
        if img_size == 224:
            self.e1_ap = nn.AvgPool2d(kernel_size=5, stride=3, padding=0)
            self.e1_conv = BasicConv2d(528, 128, kernel_size=1, stride=1, padding=0)
            self.e1_flat = nn.Flatten()
            self.e1_fc1 = nn.Linear(2048, 1024)
        elif img_size == 512:
            self.e1_ap = nn.AvgPool2d(kernel_size=10, stride=3, padding=0)
            self.e1_conv = BasicConv2d(528, 128, kernel_size=1, stride=1, padding=0)
            self.e1_flat = nn.Flatten()
            self.e1_fc1 = nn.Linear(8192, 1024)
        elif img_size == 1024:
            self.e1_ap = nn.AvgPool2d(kernel_size=20, stride=3, padding=0)
            self.e1_conv = BasicConv2d(528, 128, kernel_size=1, stride=1, padding=0)
            self.e1_flat = nn.Flatten()
            self.e1_fc1 = nn.Linear(28800, 1024)


        self.e1_r = nn.ReLU(True)
        self.e1_do = nn.Dropout(0.7)
        self.e1_fc2 = nn.Linear(1024, num_out)
        
        self.exit1 = [
            self.cc11,
            self.cmp1 ,
            self.clrn1 ,
            self.cc21,
            self.cc22 ,
            self.cmp2 ,
            self.clrn2,
            self.cinc3a ,
            self.cinc3b,
            
            self.cmp3,
            
            self.cinc4a,
            self.cinc4b,
            self.cinc4c,
            self.cinc4d,
            
            self.e1_ap, 
            self.e1_conv,

            self.e1_flat,

            self.e1_fc1, 
            self.e1_r,
            self.e1_do,
            self.e1_fc2,
        ]
        
        
        self.exits = [self.exit0,
                      self.exit1,
                      self.core
                     ]
        
    def forward(self, x, exit_point=-1):
        if isinstance(exit_point, int):
            layers = self.exits[exit_point]
        else:
            layers = exit_point
        
        for l in layers:
            if isinstance(l, nn.Flatten):
                x = torch.flatten(x, 1)
            else:
                x = l(x)
       
        return x
    
def single_GPU():
    def train_exits(dataloader, model, loss_fn, optimizer):
        size = len(dataloader.dataset)
        model.train()
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred_core = []
            for e in model.exits:
                pred_core.append(model(X, e))
                
            losses = []
            for y_hat in pred_core:
                losses.append(loss_fn(y_hat,y))
            
            loss = losses[0] * 0.15 + losses[1] * 0.15 + losses[2] * 0.6
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 50 == 0:
                current = batch * len(X)
                for indx,l in enumerate(losses):
                    tloss = l.item()
                    print(f"loss of exit {indx}: {tloss:>7f}  [{current:>5d}/{size:>5d}]")
                print('\n')
     
    # omitted for simplicity
    def test_exits(dataloader, model, loss_fn=nn.CrossEntropyLoss()):
        pass

    s = int(argv[1])
    bs = int(argv[2])
    epochs = int(argv[3])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = my_net(s).to(device)

    model.train()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    log_name = f'google logs size {s} batch_size {bs} epochs {epochs}.txt'

    for t in range(epochs):
        with open(log_name, 'a') as f:
            f.write(f'epoch: {t} \n')
        print('EPOCH:', t)
        train, test = get_data(s, batch_size=bs)
        train_exits(train, model, criterion, optimizer)
        test_exits(test, model)
        
        torch.save(model.state_dict(), f'google size {s} epochs {t} bs {bs}.pt')
        try:
            os.remove(f'google size {s} epochs {t-1} bs {bs}.pt')
        except:
            pass

Thanks for posting the question, @George_Diamantop. From the code snippets it seems like you are doing some early exits, which means not all parameters will participate in the calculation of the model output and loss, right? Can you try using find_unused_parameters=True when initializing DDP?

Currently, find_unused_parameters=True must be passed into torch.nn.parallel.DistributedDataParallel() initialization if there are parameters that may be unused in the forward pass.
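i.e. roughly, adapting the constructor call from your snippet:

ddp_model = DDP(
    my_net(s).to(rank),
    device_ids=[rank],
    output_device=rank,
    find_unused_parameters=True,  # let DDP handle parameters an early exit never touches
)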

Hello @wanchaol,

Thanks a lot for the quick reply!

I failed to mention this, since I have had many issues over the course of this parallelisation journey of mine, but I have already tried find_unused_parameters=True and got the following error:

[W python_anomaly_mode.cpp:104] Warning: Error detected in torch::autograd::AccumulateGrad. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 129, in _main
    return self._bootstrap(parent_sentinel)
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/george/anaconda3/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 25, in main
    ddp_model = DDP(my_net(s).to(rank), device_ids=[rank], output_device=rank, find_unused_parameters=True)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 587, in __init__
    self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 632, in _ddp_init_helper
    self.reducer = dist.Reducer(
 (function _print_stack)
Traceback (most recent call last):
  File "/home/george/edge ai/train_googlenet.py", line 173, in <module>
    data_parallel()
  File "/home/george/edge ai/train_googlenet.py", line 70, in data_parallel
    mp.spawn(
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/george/edge ai/train_googlenet.py", line 50, in main
    loss.backward()
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/george/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 186 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

When using a static graph (_set_static_graph()), I got the following:

RuntimeError: Your training graph has changed in this iteration, e.g., one parameter is unused in first iteration, but then got used in the second iteration. this is not compatible with static_graph set to True.

So I assume this also leads to a dead end, unless I am missing something.

A possible solution might be hidden in this:

Please make sure model parameters are not shared across multiple concurrent forward-backward passes

Although I have no clue which parameters are shared, or how, since each forward/backward pass is separate from the others, at least in the code I implemented (I am not sure about the PyTorch backend, but I assume some parameter sharing must be going on for the synchronisation of gradients).
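One direction I am considering (purely an untested sketch on my end, not something suggested in this thread): have forward() return the outputs of all exits in a single call, so that each iteration performs exactly one forward pass through the DDP wrapper followed by one backward.

import torch
import torch.nn as nn

# untested sketch: a forward for my_net that evaluates every exit in one call
def forward(self, x):
    outputs = []
    for layer_list in self.exits:
        out = x
        for l in layer_list:
            out = torch.flatten(out, 1) if isinstance(l, nn.Flatten) else l(out)
        outputs.append(out)
    return outputs

# the training step would then look something like:
#   preds = ddp_model(X)  # one DDP forward per iteration
#   loss = 0.15 * criterion(preds[0], y) + 0.15 * criterion(preds[1], y) + 0.6 * criterion(preds[2], y)

It would still recompute the shared trunk once per exit, just like my current loop, but at least the DDP wrapper would only be entered once per iteration.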

I will edit the original post to provide this information as well.

Thanks a lot for taking the time to answer,
George