I get the same error with PyTorch 0.2, using the code below, adapted from Marvin Cao.
Does anyone have a solution?
import torch
from torch.autograd import Variable


def compute_gradient_penalty(discriminator, mixed, real, fake, LAMBDA=10,
                             use_cuda=True):
    batch_size = real.size(0)
    # Sample one random interpolation coefficient per example.
    alpha = torch.rand(batch_size, 1)
    alpha = alpha.expand(real.size()).contiguous()
    alpha = alpha.cuda(async=True) if use_cuda else alpha

    # Interpolate between the real and fake samples.
    interpolates = alpha * real.data + ((1 - alpha) * fake.data)
    if use_cuda:
        interpolates = interpolates.cuda(async=True)
    interpolates = Variable(interpolates, requires_grad=True)

    disc_interpolates = discriminator(mixed, interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    grad_outputs = torch.ones(disc_interpolates.size())
    if use_cuda:
        grad_outputs = grad_outputs.cuda(async=True)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Penalize deviation of the gradient norm from 1 (WGAN-GP penalty).
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return penalty
The discriminator is of type DataParallel and the variables are of type autograd.Variable.
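For context, here is a minimal self-contained way I call it. The Critic below is just a stand-in for my real model (its architecture and the tensor shapes here are placeholders); its two-input forward matches the discriminator(mixed, interpolates) call above, and the snippet assumes a CUDA device is available:

import torch
import torch.nn as nn
from torch.autograd import Variable

# Stand-in critic that takes two inputs, matching the call above.
class Critic(nn.Module):
    def __init__(self, dim):
        super(Critic, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1))

    def forward(self, mixed, x):
        # Condition the critic on the mixture by concatenating the inputs.
        return self.net(torch.cat([mixed, x], 1))

dim, batch_size = 64, 16
discriminator = nn.DataParallel(Critic(dim)).cuda()
mixed = Variable(torch.randn(batch_size, dim).cuda())
real = Variable(torch.randn(batch_size, dim).cuda())
fake = Variable(torch.randn(batch_size, dim).cuda())

penalty = compute_gradient_penalty(discriminator, mixed, real, fake)
penalty.backward()

The error appears at the penalty.backward() step with this setup; without the DataParallel wrapper the same code runs.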