How to run WGAN-GP on multiple GPUs (DataParallel)

Has anyone gotten WGAN-GP running under DataParallel? My attempt ends up with:

RuntimeError: arguments are located on different GPUs at /pytorch/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

Here is an outline of what I have done:


class WGAN(torch.nn.Module):
    # ...
    def forward(self, input):
        # ...
        return L_true + L_fake, gradients  # per-sample critic loss and the gradients used for the penalty


class WGAN_GP_Loss(torch.nn.Module):
    def __init__(self, lambda_gp):
        super(WGAN_GP_Loss, self).__init__()  # required so the module can be called
        self.lambda_gp = lambda_gp

    def forward(self, L_critic, gradients):
        L_critic = L_critic.mean()  # average the per-sample critic losses gathered from the GPUs
        L_gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()  # gradient penalty: mean of (||grad||_2 - 1)^2
        return L_critic + self.lambda_gp * L_gp

# ... get input in the training loop
L_critic, grad = wgan(input)  # wgan is wrapped in nn.DataParallel
loss = wgan_gp_loss(L_critic, grad)
loss.backward()  # triggers the error: arguments are located on different GPUs
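
To fill in the elided part of forward: the gradients come from an interpolation between real and fake samples, computed inside the module so that each DataParallel replica works on its own device. A minimal sketch of that shape (the critic network, the feature size, and the assumption that input packs the real and fake batches are placeholders, not my actual model):

import torch
import torch.nn as nn


class WGAN(nn.Module):
    def __init__(self):
        super(WGAN, self).__init__()
        # placeholder critic; the real network is larger
        self.critic = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))

    def forward(self, input):
        real, fake = input  # assumption: input packs the real and fake batches

        # per-sample critic losses
        L_true = -self.critic(real)
        L_fake = self.critic(fake)

        # interpolate on the same device as the scattered inputs
        alpha = torch.rand(real.size(0), 1, device=real.device)
        interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        d_interp = self.critic(interp)

        # gradients of the critic output w.r.t. the interpolated samples;
        # create_graph=True so the penalty itself can be backpropagated later
        gradients, = torch.autograd.grad(
            outputs=d_interp, inputs=interp,
            grad_outputs=torch.ones_like(d_interp),
            create_graph=True)
        return L_true + L_fake, gradients

# usage sketch:
# wgan = nn.DataParallel(WGAN()).cuda()
# L_critic, grad = wgan((real.cuda(), fake.cuda()))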

We fixed some bugs on master over the last week with respect to higher-order gradients and multi-GPU. You might need the latest master to unblock yourself.

Sorry for the trouble.
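
As a quick check before retrying, the version string shows whether the installed build is a release or a recent source build of master (source builds usually carry a commit-hash suffix):

import torch
print(torch.__version__)  # a release prints e.g. '0.3.0'; a master build usually shows an a0/dev version with a commit hash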
