Scatter not differentiable twice

I get the same "Scatter is not differentiable twice" error with PyTorch 0.2, using the code below, which is adapted from Marvin Cao's WGAN-GP gradient penalty.
Does anyone have a solution?

import torch
from torch.autograd import Variable


def compute_gradient_penalty(discriminator, mixed, real, fake, LAMBDA=10,
                             use_cuda=True):
    batch_size = real.size(0)

    # One interpolation coefficient per example, broadcast to the data shape.
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(real.size()).contiguous()
    alpha = alpha.cuda(async=True) if use_cuda else alpha

    # Random points on the line between each real/fake pair.
    interpolates = alpha * real.data + (1 - alpha) * fake.data
    if use_cuda:
        interpolates = interpolates.cuda(async=True)
    interpolates = Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(mixed, interpolates)

    grad_outputs = torch.ones(disc_interpolates.size())
    if use_cuda:
        grad_outputs = grad_outputs.cuda(async=True)

    # create_graph=True keeps the graph of the gradient itself so the
    # penalty can be backpropagated (this is the double backward).
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Penalize deviation of each example's gradient norm from 1.
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return penalty
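
For reference, this is roughly how I call it in the discriminator update; netD, mixed, real_v, fake_v, and optimizer_D are placeholders for my actual objects:

    # Sketch of one discriminator step (all names are placeholders).
    netD.zero_grad()
    d_real = netD(mixed, real_v).mean()
    d_fake = netD(mixed, fake_v).mean()
    gp = compute_gradient_penalty(netD, mixed, real_v, fake_v)
    d_loss = d_fake - d_real + gp
    d_loss.backward()  # the Scatter error surfaces during this backward
    optimizer_D.step()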

The discriminator is wrapped in nn.DataParallel and all inputs are autograd.Variable objects.
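
A possible workaround, sketched below but untested, is to bypass DataParallel for the penalty term by calling the wrapped network directly via .module, so the double backward never goes through Scatter/Gather. The cost is that the penalty is computed on a single GPU:

    # Untested workaround sketch: nn.DataParallel exposes the raw network
    # as .module; calling it directly keeps the gradient penalty's double
    # backward off the multi-GPU Scatter/Gather path.
    penalty = compute_gradient_penalty(discriminator.module, mixed, real_v, fake_v)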