I am trying to use multiple GPUs to train the generator of a GAN model for video prediction. Training works with a single GPU (using nn.parallel.DataParallel with a single entry in device_ids, or without DataParallel at all), but it gets stuck at .backward() when device_ids contains multiple devices. The relevant part of my code is captured below. Thanks in advance!
device_ids = [0, 1, 2]
device_discriminator = device_ids[0]
G_LR = 1e-4
D_LR = 1e-4

generator = nn.parallel.DataParallel(EF(encoder, forecaster), device_ids=device_ids, dim=1).to(f'cuda:{device_ids[0]}')
discriminator = Discriminator().to(f'cuda:{device_discriminator}')
criterionG = MSE().to(f'cuda:{device_ids[0]}')
criterionD = nn.BCELoss(reduction='mean').to(f'cuda:{device_ids[0]}')
optimizerG = torch.optim.Adam(generator.parameters(), lr=G_LR)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=D_LR)

true_label = torch.tensor(1, dtype=torch.float, device=device_discriminator)  # true_label and fake_label refer to the sequence being real (1) or fake (0)
fake_label = torch.tensor(0, dtype=torch.float, device=device_discriminator)

prediction = generator(train_data)  # shape [12, 3, 1, 256, 256]; 3 is the batch size, which is also equal to the number of devices used for training
true_seq = torch.cat((train_data, train_label), 0)
pred_seq = torch.cat((train_data, prediction), 0)

generator.train()
for i in range(pred_seq.shape[1]):
    D_output = discriminator(pred_seq[:, [i], ...].to(device_discriminator))
    mse = criterionG(prediction[:, [i], ...], train_label[:, [i], ...]).to(device_discriminator)
    GAN_loss = criterionD(D_output, fake_label)
    errG = -1 * GAN_loss * weight + mse
    torch.nn.utils.clip_grad_value_(generator.parameters(), clip_value=50.0)
    for param in generator.parameters():
        param.grad = None
    errG.backward()  # STUCK HERE
    optimizerG.step()
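For reference, here is a stripped-down, self-contained sketch of the same multi-GPU setup (DataParallel with dim=1 over device_ids = [0, 1, 2] and an input of shape [12, 3, 1, 256, 256]). The DummyGenerator module and the plain MSE loss are placeholders I made up to stand in for EF(encoder, forecaster) and my real losses; I have not confirmed whether this minimal version also hangs, but it shows the structure I am using:

import torch
import torch.nn as nn

class DummyGenerator(nn.Module):
    # Placeholder for EF(encoder, forecaster): maps [T, B, 1, H, W] -> [T, B, 1, H, W]
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        t, b, c, h, w = x.shape
        out = self.conv(x.reshape(t * b, c, h, w))
        return out.reshape(t, b, c, h, w)

device_ids = [0, 1, 2]
generator = nn.parallel.DataParallel(
    DummyGenerator(), device_ids=device_ids, dim=1  # scatter/gather along dim 1 (the batch dim)
).to(f'cuda:{device_ids[0]}')

# Same layout as my real data: [time, batch, channel, H, W], batch == len(device_ids)
train_data = torch.randn(12, 3, 1, 256, 256, device=f'cuda:{device_ids[0]}')
train_label = torch.randn(12, 3, 1, 256, 256, device=f'cuda:{device_ids[0]}')

optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

prediction = generator(train_data)                 # the forward pass runs fine
loss = nn.functional.mse_loss(prediction, train_label)
for param in generator.parameters():
    param.grad = None
loss.backward()                                    # backward() is the call that hangs in my full training code
optimizer.step()
print('backward finished')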