Hi, I am trying to parallelize a GAN model using torch.nn.parallel.data_parallel.
However, once I use more than one GPU, the output becomes all zeros after several iterations.
Let me explain …
I have a G model and a D model. Both take multiple inputs and produce multiple outputs. My training loop looks roughly like this:
from functools import partial

# note: the keyword is device_ids, not gpus
data_parallel = partial(torch.nn.parallel.data_parallel, device_ids=gpu_list)

for it in range(N):
    # update D
    optimD.zero_grad()
    logit_tuple1 = data_parallel(D, (input_tuple1,))  # forward multiple times in parallel
    logit_tuple2 = data_parallel(D, (input_tuple2,))
    logit_tuple3 = data_parallel(D, (input_tuple3,))
    loss = compute_d_loss(logit_tuple1, logit_tuple2, logit_tuple3)
    loss.backward()
    optimD.step()

    # update G
    optimG.zero_grad()
    imgs = data_parallel(G, (z,))
    img_logit1 = data_parallel(D, (imgs, z1))
    img_logit2 = data_parallel(D, (imgs, z2))
    loss = compute_g_loss(img_logit1, img_logit2)
    loss.backward()
    optimG.step()
Only the G and D forward passes use data parallelism; everything else (loss computation and optimization) runs on the main GPU. I am not very familiar with the restrictions of data_parallel. Note that both D and G are custom nn.Module subclasses with relatively complex operations in forward().
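To make the pattern concrete, here is a minimal self-contained sketch of what I am doing, with toy linear layers standing in for my real G and D (the names and sizes are placeholders, not my actual code; on a machine with fewer than two GPUs it falls back to a plain forward call):

```python
# Minimal sketch of the parallel-forward pattern; toy models replace my real G and D.
import torch
import torch.nn as nn
from functools import partial

gpu_list = list(range(torch.cuda.device_count()))
use_dp = len(gpu_list) > 1  # data_parallel only applies with more than one GPU

G = nn.Linear(4, 8)  # placeholder generator
D = nn.Linear(8, 1)  # placeholder discriminator
if use_dp:
    G, D = G.cuda(), D.cuda()
    data_parallel = partial(torch.nn.parallel.data_parallel, device_ids=gpu_list)
else:
    # single-device fallback: just call the module directly
    data_parallel = lambda module, inputs: module(*inputs)

z = torch.randn(16, 4, device="cuda" if use_dp else "cpu")
imgs = data_parallel(G, (z,))       # inputs scattered across GPUs, outputs gathered on device 0
logits = data_parallel(D, (imgs,))  # a second parallel forward through D
print(logits.shape)                 # torch.Size([16, 1])
```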
Could anyone help me figure out why the output becomes all zeros and the loss does not decrease when using multiple GPUs? I am sure my code works fine if I set gpu_list = [0]. My guess is that forwarding D multiple times during the D update causes a gradient-accumulation issue when the per-GPU gradients are gathered. Thanks!
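To illustrate what I mean by gradient accumulation: calling backward() repeatedly without zeroing gradients sums them into .grad, which is the kind of effect I suspect might interact badly with the multi-GPU gather (tiny CPU-only example, unrelated to my actual model):

```python
# Gradient accumulation in PyTorch: backward() adds into .grad
# unless the gradients are zeroed in between.
import torch

w = torch.ones(1, requires_grad=True)
(2 * w).backward()
print(w.grad)        # tensor([2.])
(3 * w).backward()   # no zero_grad() in between
print(w.grad)        # tensor([5.]) -- the two gradients summed
```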