I have encounter RuntimeError when I use DistributedDataParallel:
error
[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
File "main.py", line 489, in <module>
img_A_id = decoders[1](encoder(CT_patches))
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/MR2CT/resnet.py", line 212, in forward
x8 = self.layer4(x7)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/MR2CT/resnet.py", line 64, in forward
residual = self.downsample(x)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
return F.batch_norm(
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2282, in batch_norm
return torch.batch_norm(
(function _print_stack)
0it [00:07, ?it/s]
Traceback (most recent call last):
File "main.py", line 503, in <module>
GAN_total_loss.backward()
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/THL5/home/hugpu1/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 9; expected version 8 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
code
def setmodels(device):
opt = Option()
encoder, _ = generate_model(opt)
encoder = encoder.to(device)
decoders = [Gen().to(device), Gen().to(device)]
extractors = [Extractor().to(device), Extractor().to(device)]
Discriminators = [Dis().to(device), Dis().to(device)]
if torch.distributed.is_initialized():#change this line to False then works fine.
encoder = torch.nn.parallel.DistributedDataParallel(encoder, device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
decoders[0] = torch.nn.parallel.DistributedDataParallel(decoders[0], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
decoders[1] = torch.nn.parallel.DistributedDataParallel(decoders[1], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
extractors[0] = torch.nn.parallel.DistributedDataParallel(extractors[0], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
extractors[1] = torch.nn.parallel.DistributedDataParallel(extractors[1], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
Discriminators[0] = torch.nn.parallel.DistributedDataParallel(Discriminators[0], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
Discriminators[1] = torch.nn.parallel.DistributedDataParallel(Discriminators[1], device_ids=[
torch.distributed.get_rank() % torch.cuda.device_count()],
output_device=torch.distributed.get_rank() % torch.cuda.device_count(),
find_unused_parameters=True)
return encoder,decoders,extractors,Discriminators
if __name__ == '__main__':
setEnv()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder,decoders,extractors,Discriminators = setmodels(device)
data_loader = setDataloader()
lr={"encoder":0.001,"decoders0":0.01,"decoders1":0.01,"extractors0":0.01,"extractors1":0.01,"Discriminators0":0.01,"Discriminators1":0.01}
encoder_optimizer, decoder_optimizers, extractor_optimizers, Dis_optimizers = setOptimizers(encoder,decoders,extractors,Discriminators,lr)
lambda_cycle = 10
lambda_id = 0.9 * lambda_cycle
loss_hparam = {"lambda_cycle":lambda_cycle,"lambda_id":lambda_id}
writer = setTensorboard("./tb_logs")
hparam = {}
hparam.update(lr)
hparam.update(loss_hparam)
if writer is not None:
writer.add_hparams(hparam,{"accuracy":1})
lossfn = nn.MSELoss()
loss_mae = nn.L1Loss()
loss_binary = nn.BCEWithLogitsLoss()
batch_size = 4
epochs=200
valid = torch.ones((batch_size,1)).to(device)
fake = torch.zeros((batch_size,1)).to(device)
global_step = 0
with torch.autograd.set_detect_anomaly(True):
for epoch in range(epochs):
tk = tqdm(enumerate(data_loader.load_batch(batch_size)),position=0,leave=True)
for i,(CT_patches,MR_patches,CT_para,MR_para) in tk:
CT_patches,MR_patches,CT_para,MR_para = CT_patches.to(device),MR_patches.to(device),CT_para.to(device),MR_para.to(device)
set_require_grad(decoders[0],False)
set_require_grad(decoders[1],False)
d_A_loss_real = loss_binary(Discriminators[0](CT_patches),valid)
Dis_optimizers[0].zero_grad()
d_A_loss_real.backward()
Dis_optimizers[0].step()
#生成样本损失
fake_CT =decoders[1](encoder(MR_patches))
valid_A = Discriminators[0](fake_CT)
d_A_loss_fake = loss_binary(valid_A,fake)
Dis_optimizers[0].zero_grad()
d_A_loss_fake.backward()
Dis_optimizers[0].step()
#训练mean,std预测器
false_CT_para = extractors[0](CT_patches)
loss_CTpara = lossfn(false_CT_para,CT_para)
extractor_optimizers[0].zero_grad()
loss_CTpara.backward()
extractor_optimizers[0].step()
false_MR_para = extractors[1](MR_patches)
loss_MRpara = lossfn(false_MR_para,MR_para)
extractor_optimizers[1].zero_grad()
loss_MRpara.backward()
extractor_optimizers[1].step()
#真实样本损失
d_B_loss_real = loss_binary(Discriminators[1](MR_patches),valid)
Dis_optimizers[1].zero_grad()
d_B_loss_real.backward()
Dis_optimizers[1].step()
#生成样本损失
fake_MR = decoders[0](encoder(CT_patches))
valid_B = Discriminators[1](fake_MR)
d_B_loss_fake = loss_binary(valid_B,fake)
Dis_optimizers[1].zero_grad()
d_B_loss_fake.backward()
Dis_optimizers[1].step()
#训练生成器
set_require_grad(decoders[0],True)
set_require_grad(decoders[1],True)
del fake_CT,valid_A,fake_MR,valid_B
fake_CT =decoders[1](encoder(MR_patches))
valid_A = Discriminators[0](fake_CT)
fake_MR = decoders[0](encoder(CT_patches))
valid_B = Discriminators[1](fake_MR)
reconstr_CT = decoders[1](encoder(fake_MR))
reconstr_MR = decoders[0](encoder(fake_CT))
img_A_id = decoders[1](encoder(CT_patches))
img_B_id = decoders[0](encoder(MR_patches))
loss1 =loss_binary(valid_A,valid)
loss2 = loss_binary(valid_B,valid)
loss3 =lambda_cycle*loss_mae(reconstr_CT,CT_patches)
loss4 = lambda_cycle*loss_mae(reconstr_MR,MR_patches)
loss5 = lambda_id*loss_mae(img_A_id,CT_patches)
loss6 = lambda_id*loss_mae(img_B_id,MR_patches)
GAN_total_loss = loss1+loss2+loss3+loss4+loss5+loss6
encoder_optimizer.zero_grad()
decoder_optimizers[0].zero_grad()
decoder_optimizers[1].zero_grad()
Dis_optimizers[0].zero_grad()
Dis_optimizers[1].zero_grad()
GAN_total_loss.backward()
decoder_optimizers[0].step()
decoder_optimizers[1].step()
unimportant code
def setDataloader():
if torch.distributed.is_initialized():
world_size = int(os.environ['SLURM_NTASKS'])
data_loader = DataLoader(torch.distributed.get_rank(), world_size)
else:
data_loader = DataLoader()
return data_loader
def setEnv():
if torch.distributed.is_available() and torch.cuda.device_count()>1:
rank = int(os.environ['SLURM_PROCID'])
local_rank = int(os.environ['SLURM_LOCALID'])
world_size = int(os.environ['SLURM_NTASKS'])
ip = get_ip()
dist_init(ip, rank, local_rank, world_size)
def setOptimizers(encoder,decoders,extractors,Discriminators,lr):
decoder_optimizers = [optim.Adam(decoders[i].parameters(),lr=lr["decoders"+str(i)]) for i in range(2)]
extractor_optimizers = [optim.Adam(extractors[i].parameters(),lr=lr["extractors"+str(i)]) for i in range(2)]
encoder_optimizer = optim.SGD(encoder.parameters(), lr=lr["encoder"], momentum=0.9, weight_decay=1e-3)
Dis_optimizers = [optim.Adam(Discriminators[i].parameters(),lr["Discriminators"+str(i)]) for i in range(2)]
return encoder_optimizer,decoder_optimizers,extractor_optimizers,Dis_optimizers
def setTensorboard(path):
if torch.distributed.is_initialized():
if torch.distributed.get_rank()==0:
writer = SummaryWriter(path)
else:
writer=None
else:
writer = SummaryWriter(path)
return writer
def recordDict(writer,global_step,**kwargs):
for key,value in kwargs.items():
writer.add_scalar(key, value, global_step)
def set_require_grad(model,trainable):
for param in model.parameters():
param.requires_grad = trainable
def get_ip():
slurm_nodelist = os.environ.get("SLURM_NODELIST")
if slurm_nodelist:
root_node = slurm_nodelist.split(" ")[0].split(",")[0]
else:
root_node = "127.0.0.1"
if '[' in root_node:
name, numbers = root_node.split('[', maxsplit=1)
number = numbers.split(',', maxsplit=1)[0]
if '-' in number:
number = number.split('-')[0]
number = re.sub('[^0-9]', '', number)
root_node = name + number
return root_node
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
host_addr_full = 'tcp://' + host_addr + ':' + str(port)
torch.distributed.init_process_group("nccl", init_method=host_addr_full,
rank=rank, world_size=world_size)
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(local_rank)
assert torch.distributed.is_initialized()