torch.nn.parallel.DistributedDataParallel() problems: "NoneType" error / CalledProcessError / backward

In a GAN-based model that contains one generator and three discriminators, all the models are wrapped in torch.nn.parallel.DistributedDataParallel() with different process_group arguments. The loss function contains two parts, like this:
d_total_loss = d_real_loss + d_fake_loss
and the backpropagation is: d_total_loss.backward()
When I run the program, the error is:

[first error screenshot]
But when I run d_real_loss.backward() or d_fake_loss.backward() alone, the program runs normally.
What’s more, I have another problem: when I use generator_model.train() in my program and run it, there is an error:

[second error screenshot]
Could you give me some advice to solve these problems?

Hey @lzkzls, could you please share a minimal reproducible example?

The first error picture does not seem to show a DDP-specific error. Does the code run correctly without DDP? It looks like the autograd graphs that produce d_real_loss and d_fake_loss share some operators/parameters.
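To make that concrete: below is a minimal single-process sketch of the failing pattern, two forward passes whose graphs share the same DDP-wrapped parameters followed by one backward. The gloo backend and world size of 1 are illustrative choices, not from the original program, and the exact error message varies across PyTorch versions:

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process setup purely for illustration; a real run would use
# torch.distributed.launch with one process per GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = DDP(nn.Linear(4, 1))
x_fake, x_real = torch.randn(8, 4), torch.randn(8, 4)

# Problematic pattern (commented out): two forwards through the same DDP
# instance before a single backward. Shared parameters get marked ready
# twice in one reduction pass, and depending on the version DDP raises
# "Expected to mark a variable ready only once" during backward or
# "Expected to have finished reduction in the prior iteration ..." at
# the second forward.
# loss = model(x_fake).sum() + model(x_real).sum()
# loss.backward()

# Working pattern: a single forward that consumes both batches, so the
# backward pass marks every parameter ready exactly once.
loss = model(torch.cat([x_fake, x_real], dim=0)).sum()
loss.backward()

dist.destroy_process_group()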

The second error picture seems to suggest that generator_model is None. It would be helpful to see a self-contained repro of this error.

When I use torch.nn.DataParallel(), the code runs correctly.
Thank you so much!
Here is a minimal reproducible example:

[example code]
I run this code with the command: sudo CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 DDP_test.py
Note: when I uncomment lines 77 and 78, there is an error:

[third error screenshot]
Hey @lzkzls

The following code works for me. I found two errors:

  1. The original code didn’t set local_rank correctly. It needs to read the --local_rank argument instead of hardcoding it to 0.
  2. For DDP, you need to interleave forward and backward calls, rather than running two forwards followed by one backward. This is fixed by having the forward function of Discriminator take both the fake and real images.
import argparse
import torch, torchvision
import torch.nn as nn
import torch.distributed as dist
import torchvision.transforms as transforms
import torch.optim as optim


# input: MNIST images of shape (1, 28, 28)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv2 = nn.ModuleList()
        self.conv2.append(nn.Sequential(nn.Conv2d(1, 16, 3, stride=2, padding=1),
                                        nn.BatchNorm2d(16),
                                        nn.LeakyReLU(negative_slope=0.2)
        ))

        self.conv2.append(nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1),
                                        nn.BatchNorm2d(32),
                                        nn.LeakyReLU(negative_slope=0.2)
                        ))
        self.conv2.append(nn.Sequential(nn.Conv2d(32, 64, 3, stride=2, padding=1),
                                        nn.BatchNorm2d(64),
                                        nn.LeakyReLU(negative_slope=0.2)
        ))
        self.conv2.append(nn.Sequential(nn.Conv2d(64, 1, 3, stride=2),
                                        nn.BatchNorm2d(1),
                                        nn.LeakyReLU(negative_slope=0.2)
        ))
    def forward(self, fake, real):
        # Run the fake and real batches through the same layers in a single
        # forward call, so DDP sees exactly one forward per backward.
        for conv_layer in self.conv2:
            fake = conv_layer(fake)
            real = conv_layer(real)

        return fake.view(-1,1), real.view(-1, 1)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.deconv2 = nn.ModuleList()
        self.deconv2.append(nn.Sequential(nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,padding=1),
                            nn.BatchNorm2d(32),
                            nn.LeakyReLU()
        ))
        self.deconv2.append(nn.Sequential(nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2,padding=1),
                            nn.BatchNorm2d(16),
                            nn.LeakyReLU()
        ))
        self.deconv2.append(nn.Sequential(nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2,padding=1),
                            nn.BatchNorm2d(1),
                            nn.LeakyReLU()
        ))
    def forward(self, x):
        for layer in self.deconv2:
            x = layer(x)

        return x

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--local_world_size", type=int, default=1)
args = parser.parse_args()

local_rank = args.local_rank
dist.init_process_group(backend='nccl', init_method='env://')

discriminator_model = Discriminator()
generator_model = Generator()

torch.cuda.set_device(local_rank)
discriminator_model.cuda(local_rank)
generator_model.cuda(local_rank)

pg1 = dist.new_group(range(dist.get_world_size()))
pg2 = dist.new_group(range(dist.get_world_size()))
# Give each DDP instance its own process group so their collectives
# cannot interfere with one another.
discriminator_model = torch.nn.parallel.DistributedDataParallel(discriminator_model, device_ids=[local_rank],
                                                                output_device=local_rank, process_group=pg1)
generator_model = torch.nn.parallel.DistributedDataParallel(generator_model, device_ids=[local_rank],
                                                            output_device=local_rank, process_group=pg2)

# disciminator_model = disciminator_model.train()
# generator_model = generator_model.train()

g_optimizer = optim.Adam(params=generator_model.parameters(), lr=1e-4)
d_optimizer = optim.Adam(params=discriminator_model.parameters(), lr=1e-4)
bcelog_loss = nn.BCEWithLogitsLoss().cuda(local_rank)

train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
batch_size = 8
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=4,
                                           pin_memory=True,
                                           sampler=train_sampler)

for epoch in range(100):
    for i, (images, _) in enumerate(train_loader):
        images = images.cuda(local_rank, non_blocking=True)
        real_tensor = torch.full((batch_size, 1), 1, dtype=torch.float32).cuda(local_rank)
        fake_tensor = torch.zeros((batch_size, 1), dtype=torch.float32).cuda(local_rank)
        # Put the noise on the local GPU explicitly, matching the other inputs.
        noise_tensor = torch.rand((batch_size, 64, 4, 4)).cuda(local_rank)
        gen_image = generator_model(noise_tensor)

        d_fake, d_real = discriminator_model(gen_image, images)
        # d_real = discriminator_model(images)

        d_fake_loss = bcelog_loss(d_fake, fake_tensor)
        d_real_loss = bcelog_loss(d_real, real_tensor)

        d_total_loss = d_fake_loss + d_real_loss

        g_optimizer.zero_grad()
        d_optimizer.zero_grad()

        d_total_loss.backward()
        g_optimizer.step()
        d_optimizer.step()
        if i % 10 == 0:
            print(f"processed {i} images")
    print("current epoch: ", epoch)

Thank you! You are so great! What should I do if I have two different Discriminator classes?

What should I do if I have two different Discriminator classes?

As long as forward and backward on each DDP instance are called alternately, it should work. So, there are at least two options:

  1. Wrap the two Discriminators into one nn.Module, say CombinedDiscriminator, and then pass the CombinedDiscriminator to the DDP constructor (see the sketch after this list).
  2. Create a dedicated DDP instance (with its own dedicated ProcessGroup instance) for each Discriminator.
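For option 1, here is a minimal sketch. CombinedDiscriminator is the name suggested above, while DiscriminatorA and DiscriminatorB are hypothetical stand-ins for your two discriminator classes, each assumed to take a single image batch:

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class CombinedDiscriminator(nn.Module):
    """Holds both discriminators so a single DDP instance manages them."""
    def __init__(self, disc_a, disc_b):
        super().__init__()
        self.disc_a = disc_a
        self.disc_b = disc_b

    def forward(self, fake, real):
        # One forward call produces all four outputs, so one backward over
        # the combined loss marks every parameter ready exactly once.
        return (self.disc_a(fake), self.disc_a(real),
                self.disc_b(fake), self.disc_b(real))

# Usage sketch (assumes init_process_group has already run, as in the
# training script above):
# combined = CombinedDiscriminator(DiscriminatorA(), DiscriminatorB()).cuda(local_rank)
# combined = DDP(combined, device_ids=[local_rank], output_device=local_rank)
# d_fake_a, d_real_a, d_fake_b, d_real_b = combined(gen_image, images)

With option 2, the same rule applies per wrapper: each Discriminator’s DDP instance must still see exactly one forward before each backward.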