Why do I get a RuntimeError when I use DistributedDataParallel?

I encounter a RuntimeError when I use DistributedDataParallel:

Error:
[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "main.py", line 489, in <module>
    img_A_id = decoders[1](encoder(CT_patches))
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/MR2CT/resnet.py", line 212, in forward
    x8 = self.layer4(x7)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/MR2CT/resnet.py", line 64, in forward
    residual = self.downsample(x)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2282, in batch_norm
    return torch.batch_norm(
 (function _print_stack)
0it [00:07, ?it/s]
Traceback (most recent call last):
  File "main.py", line 503, in <module>
    GAN_total_loss.backward()
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 9; expected version 8 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
However, if I comment out the DistributedDataParallel wrapping, everything works fine. My code is as follows:
Code:
def setmodels(device):
    opt = Option()
    encoder, _ = generate_model(opt)
    encoder = encoder.to(device)
    decoders = [Gen().to(device), Gen().to(device)]
    extractors = [Extractor().to(device), Extractor().to(device)]
    Discriminators = [Dis().to(device), Dis().to(device)]
    if torch.distributed.is_initialized():  # if I change this condition to False, everything works fine
        local_rank = torch.distributed.get_rank() % torch.cuda.device_count()

        def wrap_ddp(model):
            return torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank,
                find_unused_parameters=True)

        encoder = wrap_ddp(encoder)
        decoders = [wrap_ddp(m) for m in decoders]
        extractors = [wrap_ddp(m) for m in extractors]
        Discriminators = [wrap_ddp(m) for m in Discriminators]
    return encoder,decoders,extractors,Discriminators
if __name__ == '__main__':
    setEnv()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder,decoders,extractors,Discriminators = setmodels(device)
    data_loader = setDataloader()
    lr={"encoder":0.001,"decoders0":0.01,"decoders1":0.01,"extractors0":0.01,"extractors1":0.01,"Discriminators0":0.01,"Discriminators1":0.01}
    encoder_optimizer, decoder_optimizers, extractor_optimizers, Dis_optimizers = setOptimizers(encoder,decoders,extractors,Discriminators,lr)
    lambda_cycle = 10
    lambda_id = 0.9 * lambda_cycle
    loss_hparam = {"lambda_cycle":lambda_cycle,"lambda_id":lambda_id}
    writer = setTensorboard("./tb_logs")
    hparam = {}
    hparam.update(lr)
    hparam.update(loss_hparam)
    if writer is not None:
        writer.add_hparams(hparam,{"accuracy":1})
    lossfn = nn.MSELoss()
    loss_mae  = nn.L1Loss()
    loss_binary = nn.BCEWithLogitsLoss()
    batch_size = 4
    epochs=200
    valid = torch.ones((batch_size,1)).to(device)
    fake = torch.zeros((batch_size,1)).to(device)
    global_step = 0
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(epochs):
            tk = tqdm(enumerate(data_loader.load_batch(batch_size)),position=0,leave=True)
            for i,(CT_patches,MR_patches,CT_para,MR_para) in tk:
                CT_patches,MR_patches,CT_para,MR_para = CT_patches.to(device),MR_patches.to(device),CT_para.to(device),MR_para.to(device)
                set_require_grad(decoders[0],False)
                set_require_grad(decoders[1],False)
                d_A_loss_real = loss_binary(Discriminators[0](CT_patches),valid)
                Dis_optimizers[0].zero_grad()
                d_A_loss_real.backward()
                Dis_optimizers[0].step()
                # loss on generated (fake) samples
                fake_CT = decoders[1](encoder(MR_patches))
                valid_A = Discriminators[0](fake_CT)
                d_A_loss_fake = loss_binary(valid_A,fake)
                Dis_optimizers[0].zero_grad()
                d_A_loss_fake.backward()
                Dis_optimizers[0].step()
                # train the mean/std predictors (extractors)

                false_CT_para = extractors[0](CT_patches)
                loss_CTpara = lossfn(false_CT_para,CT_para)
                extractor_optimizers[0].zero_grad()
                loss_CTpara.backward()
                extractor_optimizers[0].step()

                false_MR_para = extractors[1](MR_patches)
                loss_MRpara = lossfn(false_MR_para,MR_para)
                extractor_optimizers[1].zero_grad()
                loss_MRpara.backward()
                extractor_optimizers[1].step()
                # loss on real samples
                d_B_loss_real = loss_binary(Discriminators[1](MR_patches),valid)
                Dis_optimizers[1].zero_grad()
                d_B_loss_real.backward()
                Dis_optimizers[1].step()
                # loss on generated (fake) samples
                fake_MR = decoders[0](encoder(CT_patches))
                valid_B = Discriminators[1](fake_MR)
                d_B_loss_fake = loss_binary(valid_B,fake)
                Dis_optimizers[1].zero_grad()
                d_B_loss_fake.backward()
                Dis_optimizers[1].step()
                # train the generators
                set_require_grad(decoders[0],True)
                set_require_grad(decoders[1],True)
                del fake_CT,valid_A,fake_MR,valid_B
                fake_CT = decoders[1](encoder(MR_patches))
                valid_A = Discriminators[0](fake_CT)
                fake_MR = decoders[0](encoder(CT_patches))
                valid_B = Discriminators[1](fake_MR)
                reconstr_CT  = decoders[1](encoder(fake_MR))
                reconstr_MR = decoders[0](encoder(fake_CT))
                img_A_id = decoders[1](encoder(CT_patches))
                img_B_id = decoders[0](encoder(MR_patches))
                loss1 = loss_binary(valid_A,valid)
                loss2 = loss_binary(valid_B,valid)
                loss3 = lambda_cycle*loss_mae(reconstr_CT,CT_patches)
                loss4 = lambda_cycle*loss_mae(reconstr_MR,MR_patches)
                loss5 = lambda_id*loss_mae(img_A_id,CT_patches)
                loss6 = lambda_id*loss_mae(img_B_id,MR_patches)
                GAN_total_loss = loss1+loss2+loss3+loss4+loss5+loss6
                encoder_optimizer.zero_grad()
                decoder_optimizers[0].zero_grad()
                decoder_optimizers[1].zero_grad()
                Dis_optimizers[0].zero_grad()
                Dis_optimizers[1].zero_grad()
                GAN_total_loss.backward()
                decoder_optimizers[0].step()
                decoder_optimizers[1].step()
I have also included some other, less important code:
Less important code:
def setDataloader():
    if torch.distributed.is_initialized():
        world_size = int(os.environ['SLURM_NTASKS'])
        data_loader = DataLoader(torch.distributed.get_rank(), world_size)
    else:
        data_loader = DataLoader()
    return data_loader
def setEnv():
    if torch.distributed.is_available() and torch.cuda.device_count()>1:
        rank = int(os.environ['SLURM_PROCID'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        ip = get_ip()
        dist_init(ip, rank, local_rank, world_size)
def setOptimizers(encoder,decoders,extractors,Discriminators,lr):
    decoder_optimizers = [optim.Adam(decoders[i].parameters(),lr=lr["decoders"+str(i)]) for i in range(2)]
    extractor_optimizers = [optim.Adam(extractors[i].parameters(),lr=lr["extractors"+str(i)]) for i in range(2)]
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=lr["encoder"], momentum=0.9, weight_decay=1e-3)
    Dis_optimizers = [optim.Adam(Discriminators[i].parameters(),lr=lr["Discriminators"+str(i)]) for i in range(2)]
    return encoder_optimizer,decoder_optimizers,extractor_optimizers,Dis_optimizers


def setTensorboard(path):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank()==0:
            writer = SummaryWriter(path)
        else:
            writer=None
    else:
        writer = SummaryWriter(path)
    return writer
def recordDict(writer,global_step,**kwargs):
    for key,value in kwargs.items():
        writer.add_scalar(key, value, global_step)
def set_require_grad(model,trainable):
    for param in model.parameters():
        param.requires_grad = trainable
def get_ip():
    slurm_nodelist = os.environ.get("SLURM_NODELIST")
    if slurm_nodelist:
        root_node = slurm_nodelist.split(" ")[0].split(",")[0]
    else:
        root_node = "127.0.0.1"
    if '[' in root_node:
        name, numbers = root_node.split('[', maxsplit=1)
        number = numbers.split(',', maxsplit=1)[0]
        if '-' in number:
            number = number.split('-')[0]

        number = re.sub('[^0-9]', '', number)
        root_node = name + number
    return root_node
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
    host_addr_full = 'tcp://' + host_addr + ':' + str(port)
    torch.distributed.init_process_group("nccl", init_method=host_addr_full,
                                         rank=rank, world_size=world_size)
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    assert torch.distributed.is_initialized()
Could anyone help me?

Hey @yllgl, this might be caused by a coalesced broadcast on buffers. Does your model contain buffers? If so, do you need DDP to sync buffers for you? If not, you can try setting broadcast_buffers=False in DDP ctor.
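
For illustration, here is a minimal sketch of that suggestion (not a confirmed fix), reusing the local_rank computation from setmodels above; the same flag would be passed for every model you wrap:

local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
encoder = torch.nn.parallel.DistributedDataParallel(
    encoder,
    device_ids=[local_rank],
    output_device=local_rank,
    find_unused_parameters=True,
    broadcast_buffers=False,  # skip the per-forward broadcast of buffers (e.g. BatchNorm running stats)
)

Note that with broadcast_buffers=False, DDP still synchronizes parameter gradients as usual, but buffer values such as BatchNorm running statistics are no longer kept identical across ranks.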