Hello, I am trying to train a network using DDP. The network consists of two sub-networks (a, b), and depending on the input, either only a, only b, or both a and b are executed. Things work fine on a single GPU, but when scaling to 2 or more GPUs, the backward call just hangs. Any help is appreciated. Thanks.

Below is a minimal reproducible example.

```
import numpy as np
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
import torch.nn.functional as F
class NetA(nn.Module):
    def __init__(self):
        super().__init__()
        self.a1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.a2 = nn.ReLU(inplace=True)
        self.a3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.a4 = nn.MaxPool2d(kernel_size=8, stride=8)
        self.a5 = nn.Linear(8192, 1)

    def forward(self, data):
        if data.shape[0] == 0:
            # empty batch: return a dummy output so the caller still gets a tensor
            return torch.zeros(1).cuda()
        x = self.a4(self.a3(self.a2(self.a1(data))))
        x = self.a5(torch.flatten(x, start_dim=1))
        return x


class NetB(nn.Module):
    def __init__(self):
        super().__init__()
        self.b1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.b2 = nn.ReLU(inplace=True)
        self.b3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.b4 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.b5 = nn.Linear(2048, 1)

    def forward(self, data):
        if data.shape[0] == 0:
            # empty batch: return a dummy output so the caller still gets a tensor
            return torch.zeros(1).cuda()
        x = self.b4(self.b3(self.b2(self.b1(data))))
        x = self.b5(torch.flatten(x, start_dim=1))
        return x


def main2():
    mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend='nccl')

    neta = NetA().cuda()
    netb = NetB().cuda()
    ddp_neta = torch.nn.parallel.DistributedDataParallel(
        module=neta,
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)
    ddp_netb = torch.nn.parallel.DistributedDataParallel(
        module=netb,
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=True)
    opt_a = optim.SGD(ddp_neta.parameters(), lr=0.001)
    opt_b = optim.SGD(ddp_netb.parameters(), lr=0.001)

    print('Finetuning the network... on gpu ', rank)
    for i in range(0, 20):
        opt_a.zero_grad()
        opt_b.zero_grad()
        # randomly decide which sub-network(s) receive real data this iteration
        f = np.random.rand()
        if f < 0.33:
            # both sub-networks get a non-empty batch
            out_a = ddp_neta(torch.randn(4, 3, 128, 128).to(rank))
            out_b = ddp_netb(torch.randn(2, 3, 32, 32).to(rank))
        elif f < 0.66:
            # only NetA gets real data; NetB sees an empty batch
            out_b = ddp_netb(torch.randn(0, 3, 32, 32).to(rank))
            out_a = ddp_neta(torch.randn(6, 3, 128, 128).to(rank))
        else:
            # only NetB gets real data; NetA sees an empty batch
            out_a = ddp_neta(torch.randn(0, 3, 128, 128).to(rank))
            out_b = ddp_netb(torch.randn(3, 3, 32, 32).to(rank))
        loss_a = F.softplus(out_a).mean()
        loss_b = F.softplus(out_b).mean()
        print(i, loss_a, loss_b)
        loss_a.backward()
        loss_b.backward()
        opt_a.step()
        opt_b.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    main2()
```
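
For completeness, since the script reads RANK from the environment, I launch it with the standard launcher, something like this (repro.py is just the filename I use for the snippet above, on a node with 2 GPUs):

```
torchrun --nproc_per_node=2 repro.py
```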

Any suggestions on how to fix this? Do I need to use dist.all_gather and/or dist.all_reduce to run the above snippet on multiple GPUs? I found https://github.com/pytorch/pytorch/issues/23425 and tried moving the if condition into the forward of a wrapper module that contains both NetA and NetB, but that still seems to hang at the backward step. The wrapper I tried is sketched below.
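
Roughly, the wrapper variant looks like this (a sketch; the class name and the exact empty-batch check are approximate), with the single WrapperNet then wrapped in one DistributedDataParallel instance with find_unused_parameters=True, the same way as above:

```
class WrapperNet(nn.Module):
    """Holds both sub-networks so a single DDP instance owns all parameters."""
    def __init__(self):
        super().__init__()
        self.neta = NetA()
        self.netb = NetB()

    def forward(self, data_a, data_b):
        # the input-dependent branching now happens inside the DDP-wrapped forward
        out_a = self.neta(data_a) if data_a.shape[0] > 0 else torch.zeros(1, device=data_a.device)
        out_b = self.netb(data_b) if data_b.shape[0] > 0 else torch.zeros(1, device=data_b.device)
        return out_a, out_b
```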

Thanks for the help