Gradient penalty in WGAN-GP not converging on Multi-GPU

I ran into a problem: when the gradient penalty (GP) runs on a single GPU, it converges; however, when I switch to multiple GPUs, the GP term never decreases.
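
For reference, the GP term I mean is the usual WGAN-GP penalty. Written out to match my code (here x is a real sample, x̃ a generated one, D the discriminator, and λ_gp corresponds to opt.lambda_gp), it should be:

```latex
\hat{x} = \alpha\, x + (1 - \alpha)\, \tilde{x}, \qquad \alpha \sim \mathcal{U}[0, 1]

\mathrm{GP} = \mathbb{E}_{\hat{x}}\!\left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right],
\qquad
\mathcal{L}_{\mathrm{gp}} = \lambda_{\mathrm{gp}} \cdot \mathrm{GP}
```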

Here’s my code:

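import numpy as np
import torch
from torch import autograd
from torch.autograd import Variable

# Assumption: Tensor is an alias defined elsewhere in my script, along these lines
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

def compute_gradient_penalty(D, real_samples, fake_samples):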
    """Calculates the gradient penalty loss"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

discriminator = some_network()
discriminator = torch.nn.DataParallel(discriminator)

while True:
    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, gen_imgs.detach())
    loss_gp = opt.lambda_gp * gradient_penalty
    optimizer_D.zero_grad()
    loss_gp.backward(retain_graph=True)
    optimizer_D.step()
    print(loss_gp.item())
    if loss_gp.item() < 1.:
        break
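
In case it helps with reproducing this, here is a minimal self-contained sketch of the same penalty that can be run with or without DataParallel. It is not my actual training script: the tiny critic, the tensor shapes, and the name gradient_penalty are made up for illustration, and the interpolation weights are drawn with torch.rand on the input's device instead of going through NumPy and the Tensor alias. I am not claiming this change fixes the multi-GPU behaviour; it only removes a few moving parts.

```python
import torch
from torch import autograd, nn


def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty; interpolation weights live on the same device as `real`."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    grads = autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )[0]
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the real discriminator
critic = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 16),
    nn.LeakyReLU(0.2),
    nn.Linear(16, 1),
).to(device)

# Only wrap when more than one GPU is visible, so both code paths can be compared
if torch.cuda.device_count() > 1:
    critic = nn.DataParallel(critic)

real = torch.randn(4, 3, 8, 8, device=device)
fake = torch.randn(4, 3, 8, 8, device=device)

penalty = gradient_penalty(critic, real, fake)
penalty.backward()
print("penalty:", penalty.item())
```

Running this once with one visible GPU and once with two, using the same seed, should show whether the penalty or the parameter gradients already differ after a single step.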

Note that I always wrap the discriminator in torch.nn.DataParallel(), but the GP only converges when I set CUDA_VISIBLE_DEVICES="0" (or any other single GPU ID) in the bash script. If I set CUDA_VISIBLE_DEVICES="0,1", loss_gp stays at roughly the same magnitude and never converges.
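
To make sure each run is labelled with the setting it actually used, a small sketch like the following can be called at startup. The helper name visible_gpu_ids is hypothetical, and it only reads the environment variable; it does not query CUDA itself.

```python
import os


def visible_gpu_ids(environ=os.environ):
    """Return the entries of CUDA_VISIBLE_DEVICES, or None if the variable is unset."""
    raw = environ.get("CUDA_VISIBLE_DEVICES")
    if raw is None:
        return None  # unset: no restriction was requested through this variable
    return [entry.strip() for entry in raw.split(",") if entry.strip()]


ids = visible_gpu_ids()
label = "no restriction set" if ids is None else f"{len(ids)} visible GPU(s): {ids}"
print(f"run config -> {label}")
```

Logging this next to loss_gp makes it easy to tell the "0" runs apart from the "0,1" runs afterwards.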

Can anyone figure out why this is happening and what is wrong with my code?
Thanks!
