Scatter not differentiable twice

Hey there! I am trying out an implementation of the Improved WGAN. Unfortunately, I am getting this error: RuntimeError: Scatter is not differentiable twice when I call gradient_penalty.backward(). Could someone help? The relevant code is below:

gradient_penalty = get_grad_pen(self.Dis_net, X, X_f.cpu().data, lmbda)

get_grad_pen is defined as follows:

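import torch as t
from torch.autograd import Variable as V

def get_grad_pen(Dis_net, real_data, fake_data, lmbda):
	# One interpolation weight per sample, broadcast over channels and pixels
	epsilon = t.rand(real_data.size(0), 1, 1, 1)
	interpolated_data = real_data * epsilon + fake_data * (1 - epsilon)
	interpolated_dataV = V(interpolated_data.cuda(), requires_grad=True)
	# Sum (not mean) over the batch so each sample's gradient is not scaled by 1/N
	disc_out = Dis_net(interpolated_dataV).sum()
	gradients = t.autograd.grad(outputs=disc_out, inputs=interpolated_dataV, create_graph=True, retain_graph=True, only_inputs=True)[0]
	# L2 norm of each sample's gradient, taken over all non-batch dimensions
	grad_pen = ((gradients.reshape(gradients.size(0), -1).norm(2, dim=1) - 1).pow(2)).mean().mul(lmbda)
	return grad_pen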

V is torch.autograd.Variable, in case you were wondering.
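For reference, the penalty that get_grad_pen approximates is the gradient penalty from the Improved WGAN paper (Gulrajani et al., 2017). Here $x$ is a real sample, $\tilde{x}$ a generated one, $D$ the discriminator (Dis_net) and $\lambda$ the penalty weight (lmbda); $\epsilon$ is drawn once per sample:

```latex
\hat{x} = \epsilon\, x + (1 - \epsilon)\, \tilde{x}, \qquad \epsilon \sim U[0, 1]

L_{GP} = \lambda \, \mathbb{E}_{\hat{x}} \left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right]
```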

Also, note that this happens only on a GPU. If the model is moved to the CPU (i.e., interpolated_dataV = V(interpolated_data, requires_grad=True)), the error does not occur.
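Some background on why "differentiable twice" comes up at all: the penalty is itself built from a gradient (create_graph=True), so gradient_penalty.backward() has to differentiate every op in the graph a second time. A minimal sketch of that double differentiation on a scalar, assuming the tensor API of PyTorch 0.4 or later rather than Variable:

```python
import torch

# A scalar input that we will differentiate twice.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3  # y = x^3

# First derivative, keeping the graph so it can itself be differentiated.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3x^2 = 12

# Backpropagating through dy_dx is the second differentiation a gradient
# penalty needs; every op between x and dy_dx must support it.
dy_dx.backward()  # x.grad = d/dx (3x^2) = 6x = 12
print(dy_dx.item(), x.grad.item())  # → 12.0 12.0
```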

It appears this only happens with multi-GPU support: when I run without nn.parallel.data_parallel, it works.

I have the same error message, and it also appears only when using torch.nn.DataParallel. On a single GPU, the code works fine.
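One possible workaround on a version that lacks double backward for Scatter (a sketch, not something confirmed in this thread) is to compute only the penalty term through the wrapped network, `critic.module`, so no Scatter/Gather ops enter the penalty's graph. The trade-off is that the penalty then runs on a single device. The helper name `gradient_penalty_single_device` and the toy critic are hypothetical:

```python
import torch
import torch.nn as nn


def gradient_penalty_single_device(critic, real, fake, lmbda=10.0):
    # Hypothetical helper. If `critic` is an nn.DataParallel wrapper, use the
    # wrapped network (critic.module) so the penalty graph has no Scatter ops.
    net = critic.module if isinstance(critic, nn.DataParallel) else critic
    # Run everything on the device that holds the module's parameters.
    device = next(net.parameters()).device
    real, fake = real.detach().to(device), fake.detach().to(device)
    # One interpolation weight per sample, broadcast over the other dims.
    eps_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    eps = torch.rand(eps_shape, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    out = net(interp)
    # create_graph=True keeps the gradient differentiable for gp.backward().
    (grads,) = torch.autograd.grad(outputs=out, inputs=interp,
                                   grad_outputs=torch.ones_like(out),
                                   create_graph=True, retain_graph=True)
    # Per-sample L2 norm over all non-batch dimensions.
    norms = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return lmbda * ((norms - 1) ** 2).mean()


if __name__ == "__main__":
    critic = nn.DataParallel(nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1)))
    real = torch.randn(8, 3, 4, 4)
    fake = torch.randn(8, 3, 4, 4)
    gp = gradient_penalty_single_device(critic, real, fake)
    gp.backward()  # second differentiation, through the unwrapped module only
    print(f"gradient penalty: {gp.item():.4f}")
```

The Wasserstein loss terms can still go through the DataParallel wrapper as usual; only the penalty bypasses it.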

Differentiating Scatter twice is implemented in master; try upgrading.

I get the same error with PyTorch 0.2, using the code below adapted from Marvin Cao's implementation.
Does anyone have a solution?

import torch
from torch.autograd import Variable

def compute_gradient_penalty(discriminator, mixed, real, fake, LAMBDA=10, use_cuda=torch.cuda.is_available()):
    batch_size = real.size(0)
    # One mixing weight per sample, expanded to the full shape of real
    alpha = torch.rand(batch_size, *([1] * (real.dim() - 1)))
    alpha = alpha.expand(real.size()).contiguous()
    alpha = alpha.cuda(non_blocking=True) if use_cuda else alpha
    interpolates = alpha * real.data + ((1 - alpha) * fake.data)

    if use_cuda:
        interpolates = interpolates.cuda(non_blocking=True)
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = discriminator(mixed, interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Per-sample L2 norm of the gradient over all non-batch dimensions
    gradients = gradients.reshape(gradients.size(0), -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return penalty

The discriminator is of type DataParallel, and the variables are of type autograd…Variable.
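To check whether an installed version handles this setup, a minimal repro sketch with the same shape (a DataParallel discriminator taking `mixed` and the interpolates) may help. `TwoInputCritic`, `penalty_through_dataparallel` and the tensor sizes are hypothetical. Without GPUs, DataParallel simply calls the wrapped module, so the check is only meaningful on a machine with two or more GPUs:

```python
import torch
import torch.nn as nn


class TwoInputCritic(nn.Module):
    """Hypothetical critic that, like the one above, takes (mixed, x)."""

    def __init__(self, n_mixed=4, n_x=6):
        super().__init__()
        self.fc = nn.Linear(n_mixed + n_x, 1)

    def forward(self, mixed, x):
        return self.fc(torch.cat([mixed, x], dim=1))


def penalty_through_dataparallel(lmbda=10.0, batch_size=8):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # When GPUs are present, DataParallel expects the module on cuda:0
    # (device_ids[0] by default).
    critic = nn.DataParallel(TwoInputCritic().to(device))
    mixed = torch.randn(batch_size, 4, device=device)
    real = torch.randn(batch_size, 6, device=device)
    fake = torch.randn(batch_size, 6, device=device)

    alpha = torch.rand(batch_size, 1, device=device)  # one weight per sample
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    out = critic(mixed, interpolates)  # scattered across GPUs if several exist
    (grads,) = torch.autograd.grad(out, interpolates,
                                   grad_outputs=torch.ones_like(out),
                                   create_graph=True, retain_graph=True)
    gp = lmbda * ((grads.norm(2, dim=1) - 1) ** 2).mean()
    gp.backward()  # where the thread reports "Scatter is not differentiable twice"
    return gp.item(), critic.module.fc.weight.grad is not None


if __name__ == "__main__":
    print(penalty_through_dataparallel())
```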