I am trying to implement the gradient penalty of WGAN-GP. I referred to various PyTorch WGAN-GP repositories (example) and this post, and they all seem to use the same method to get the gradients: `autograd.grad` with `create_graph=True`.
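As I understand it, `create_graph=True` is what makes the returned gradient itself differentiable, which the penalty term needs. A toy example (unrelated to my model) of that pattern:

```python
import torch

# With create_graph=True the returned gradient g is itself part of the
# autograd graph, so a penalty built from it can be backpropagated.
x = torch.tensor([2.0], requires_grad=True)
y = (x ** 3).sum()
(g,) = torch.autograd.grad(y, x, create_graph=True)  # g = 3*x**2 = 12
penalty = ((g - 1) ** 2).sum()
penalty.backward()  # differentiates through g
print(x.grad)       # tensor([264.]) = 2*(3x^2 - 1)*6x at x = 2
```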
However, I keep running into `RuntimeError: Trying to backward through the graph a second time... Specify retain_graph=True...`. Even after setting `retain_graph=True`, I get the same error.
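From what I've read, this RuntimeError is normally raised when `backward()` traverses a graph whose buffers were already freed by an earlier backward pass, as in this unrelated toy reproducer:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = (x * x).sum()
y.backward()  # first pass frees the graph's saved tensors
y.backward()  # raises: Trying to backward through the graph a second time...
```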
What am I doing wrong/differently?
My implementation:
```python
import torch

def gradient_penalty(D, real, fake, cond):
    # Per-sample interpolation coefficient in [0, 1).
    a = torch.rand_like(real[:, 0]).unsqueeze(1)
    # Interpolate between real and fake samples; the penalty is computed
    # on the gradients w.r.t. these interpolates.
    interpolates = (a * real.data + (1 - a) * fake.data).requires_grad_(True)
    d_interpolates = D(interpolates, cond.data)
    init = torch.ones_like(d_interpolates)
    # create_graph=True keeps the gradient differentiable so the penalty
    # itself can be backpropagated.
    grads = torch.autograd.grad(d_interpolates, interpolates,
                                grad_outputs=init, only_inputs=True,
                                create_graph=True, retain_graph=True)[0]
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp
```
which I use like this (`lambda` is a reserved word in Python, so the coefficient is written `lambda_gp` here):

```python
d_loss = -wgan_loss + lambda_gp * grad_penalty
d_loss.backward()
```
If I remove the grad_penalty from the objective, my code runs fine.
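For completeness, a simplified sketch of the discriminator step I'm running (names like `G`, `opt_D`, and `noise` are placeholders for my actual modules and tensors):

```python
lambda_gp = 10.0  # placeholder value for the penalty coefficient

opt_D.zero_grad()
fake = G(noise, cond).detach()  # detached so the critic loss doesn't backprop into G
wgan_loss = D(real, cond).mean() - D(fake, cond).mean()
grad_penalty = gradient_penalty(D, real, fake, cond)
d_loss = -wgan_loss + lambda_gp * grad_penalty
d_loss.backward()
opt_D.step()
```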