Gradient penalty runs into an error even with retain_graph=True (solved)

I am trying to implement the gradient penalty of WGAN-GP. I referred to various PyTorch WGAN-GP repositories (example) and this post, and they all seem to use a similar method to get the gradients: autograd.grad with create_graph=True.

However, I keep running into RuntimeError: Trying to backward through the graph a second time... Specify retain_graph=True.... Even after adding retain_graph=True, I get the same error.

What am I doing wrong/differently?

My implementation:

import torch

def gradient_penalty(D, real, fake, cond):
    # Random interpolation coefficient per sample, broadcast over features
    a = torch.rand_like(real[:, 0]).unsqueeze(1)
    interpolates = (a * real.data + (1 - a) * fake.data).requires_grad_(True)
    d_interpolates = D(interpolates, cond.data)

    # Gradients of the critic output w.r.t. the interpolates;
    # create_graph=True so the penalty term itself is differentiable
    init = torch.ones_like(d_interpolates)
    grads = torch.autograd.grad(d_interpolates, interpolates, grad_outputs=init,
                                only_inputs=True, create_graph=True, retain_graph=True)[0]
    gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return gp

which I use like this:

d_loss = -wgan_loss + lambda_gp * grad_penalty  # lambda_gp: penalty coefficient ("lambda" itself is a reserved keyword)
d_loss.backward()
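
For context, the surrounding discriminator step looks roughly like this; optimizer_D, lambda_gp and the batch tensors (real, fake, cond) are placeholders for my actual training loop, not the exact code:

optimizer_D.zero_grad()

d_real = D(real, cond).mean()
d_fake = D(fake.detach(), cond).mean()
wgan_loss = d_real - d_fake  # Wasserstein estimate the critic tries to maximize

grad_penalty = gradient_penalty(D, real, fake, cond)
d_loss = -wgan_loss + lambda_gp * grad_penalty

d_loss.backward()
optimizer_D.step()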

If I remove the grad_penalty from the objective, my code runs fine.


Turns out that my activation layers with inplace=True were causing the problem.

It appears that retain_graph does not work as expected with inplace operations, even if it’s just ReLU, which usually has no issues with autograd.
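
For anyone hitting the same thing, the fix was just dropping inplace=True from the activations in the critic. A minimal sketch of the change (the layer sizes are made up, not my actual model):

import torch.nn as nn

in_dim, hidden_dim = 128, 256  # placeholder sizes

# Broke the double backward through the penalty for me:
# D_net = nn.Sequential(nn.Linear(in_dim, hidden_dim),
#                       nn.ReLU(inplace=True),
#                       nn.Linear(hidden_dim, 1))

# Works: out-of-place activations (inplace=False is the default)
D_net = nn.Sequential(
    nn.Linear(in_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, 1),
)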

It would be nice if this were noted somewhere in the documentation, since the error message is misleading.
