Stuck at .backward() using DataParallel with GAN

I planned to use multiple GPUs to train the generator of a GAN model for video prediction. Training works with one GPU (using nn.parallel.DataParallel with a single entry in device_ids, or without DataParallel at all), but gets stuck at .backward() when device_ids contains multiple devices. The relevant part of my code is captured below. Thanks in advance!
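For reference, both of these single-GPU setups train fine (a minimal sketch; EF, encoder and forecaster are my own modules):

generator = nn.parallel.DataParallel(EF(encoder, forecaster), device_ids=[0], dim=1).to('cuda:0')
# or, without DataParallel at all:
generator = EF(encoder, forecaster).to('cuda:0')

Only with multiple entries in device_ids does the script hang.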

import torch
import torch.nn as nn

# EF, encoder, forecaster, Discriminator, MSE, train_data, train_label and weight are defined elsewhere in my script
device_ids = [0, 1, 2]
device_discriminator = device_ids[0]

G_LR = 1e-4
D_LR = 1e-4

generator = nn.parallel.DataParallel(EF(encoder, forecaster), device_ids=device_ids, dim=1).to(f'cuda:{device_ids[0]}')
discriminator = Discriminator().to(f'cuda:{device_discriminator}')

criterionG = MSE().to(f'cuda:{device_ids[0]}')
criterionD = nn.BCELoss(reduction='mean').to(f'cuda:{device_ids[0]}')

optimizerG = torch.optim.Adam(generator.parameters(), lr=G_LR)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=D_LR)

true_label = torch.tensor(1, dtype=torch.float, device=device_discriminator) # true_label/fake_label mark the sequence as real (1) or fake (0)
fake_label = torch.tensor(0, dtype=torch.float, device=device_discriminator)

prediction = generator(train_data) # output shape [12, 3, 1, 256, 256]; the batch size of 3 (dim=1) equals the number of devices used for training

true_seq = torch.cat((train_data, train_label), 0)
pred_seq = torch.cat((train_data, prediction), 0)

generator.train()

for i in range(pred_seq.shape[1]):
    # discriminator scores the i-th sample of the predicted sequence
    D_output = discriminator(pred_seq[:, [i], ...].to(device_discriminator))

    mse = criterionG(prediction[:, [i], ...], train_label[:, [i], ...]).to(device_discriminator)
    GAN_loss = criterionD(D_output, fake_label)
    errG = -1 * GAN_loss * weight + mse

    # reset the generator's gradients before the backward pass
    for param in generator.parameters():
        param.grad = None

    errG.backward() # STUCK AT HERE with multiple device_ids

    # clip gradients after backward (they only exist once backward has run), then step
    torch.nn.utils.clip_grad_value_(generator.parameters(), clip_value=50.0)
    optimizerG.step()
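In case it helps, here is a self-contained toy version of the same pattern (DataParallel over dim=1, output gathered on device 0, then .backward()). The ToyGenerator module and the random tensors are placeholders I wrote just for this post; my real model is EF(encoder, forecaster):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGenerator(nn.Module):
    # stand-in for EF(encoder, forecaster): one conv applied frame by frame
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        # x: [seq_len, batch, channels, H, W]; DataParallel(dim=1) scatters over batch
        s, b, c, h, w = x.shape
        out = self.conv(x.reshape(s * b, c, h, w))
        return out.reshape(s, b, c, h, w)

device_ids = [0, 1, 2]
gen = nn.parallel.DataParallel(ToyGenerator(), device_ids=device_ids, dim=1).to('cuda:0')

x = torch.randn(12, 3, 1, 256, 256, device='cuda:0')  # same shape as my train_data
target = torch.randn_like(x)

pred = gen(x)
loss = F.mse_loss(pred, target)
loss.backward()  # my real training never returns from the equivalent call
print('backward finished')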