Has anyone gotten WGAN-GP running under DataParallel? My attempt ends with:
RuntimeError: arguments are located on different GPUs at /pytorch/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215
Here is a sketch of what I have done:
class WGAN(torch.nn.Module):
    # ...
    def forward(self, input):
        # ...
        # `gradients` comes from the gradient-penalty computation (elided above)
        return L_true + L_fake, gradients

class WGAN_GP_Loss(torch.nn.Module):
    def __init__(self, lambda_gp):
        super(WGAN_GP_Loss, self).__init__()  # this call was missing originally
        self.lambda_gp = lambda_gp

    def forward(self, L_critic, gradients):
        # mean critic loss plus the gradient penalty, weighted by lambda_gp
        L_critic = L_critic.mean()
        L_gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return L_critic + self.lambda_gp * L_gp
# ... get input in the training loop
L_critic, grad = wgan(input) # wgan is wrapped by DataParallel
loss = wgan_gp_loss(L_critic, grad)
loss.backward() # triggers the error: arguments are located on different GPUs
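
One workaround I have been considering, but have not yet verified on multiple GPUs, is to compute the whole gradient penalty inside forward and return one loss value per replica, so that backward never has to touch tensors gathered from other devices. This is a minimal sketch of my own: `CriticWithGP`, `critic`, and the loss sign convention E[D(fake)] - E[D(real)] + lambda * GP are my assumptions, not code from above.

import torch
import torch.nn as nn
import torch.autograd as autograd

class CriticWithGP(nn.Module):
    # Hypothetical wrapper: every term of the WGAN-GP critic loss is
    # computed inside forward, so each DataParallel replica stays on
    # its own GPU and returns a single loss value.
    def __init__(self, critic, lambda_gp=10.0):
        super(CriticWithGP, self).__init__()
        self.critic = critic
        self.lambda_gp = lambda_gp

    def forward(self, real, fake):
        fake = fake.detach()  # assumes this is the critic update step
        L_true = self.critic(real).mean()
        L_fake = self.critic(fake).mean()

        # random interpolation between real and fake, on this shard's device
        eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)),
                         device=real.device)
        interp = (eps * real + (1 - eps) * fake).requires_grad_(True)

        d_interp = self.critic(interp)
        grads, = autograd.grad(outputs=d_interp, inputs=interp,
                               grad_outputs=torch.ones_like(d_interp),
                               create_graph=True)
        grads = grads.view(grads.size(0), -1)
        L_gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()

        # 1-element tensor per replica so DataParallel can gather them
        return (L_fake - L_true + self.lambda_gp * L_gp).unsqueeze(0)

DataParallel then gathers one scalar per GPU, and the training loop only has to average them:

wgan = nn.DataParallel(CriticWithGP(critic, lambda_gp=10.0))
loss = wgan(real_batch, fake_batch).mean()  # mean over the per-GPU losses
loss.backward()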