torch.autograd.grad still frees buffers even though retain_graph=True

I’m using the torch.autograd.grad function to calculate the “gradients w.r.t. gradients” needed to train a WGAN-GP model (paper). However, after upgrading to v1.0.1.post2, the buffers in the graph still seem to get freed even though I pass retain_graph=True. The same code runs fine on v1.0.0.

Here’s my code for calculating the GP (gradient penalty).

import random

import torch
from torch import autograd

# note: `device` and `grad_penalty_coeff` are defined at module level elsewhere in my code
def calc_gradient_penalty(real_samps, gen_samps, discriminator):
    """
    Calculate the gradient penalty (GP) term of the WGAN-GP loss.
    """
    # linear interpolation of real / fake images
    alpha = random.random()  # interpolation constant
    samps_interp = (real_samps - gen_samps) * alpha + gen_samps
    # set `requires_grad=True` to store the grad value
    samps_interp = samps_interp.detach().requires_grad_()

    # pass through discriminator
    score = discriminator(samps_interp)

    # calculate the gradients of scores w.r.t. interpolated images
    grads = autograd.grad(
        outputs=score,
        inputs=samps_interp,
        grad_outputs=torch.ones(score.size()).to(device),
        create_graph=True,
        retain_graph=True)[0]

    grad_per_samps = grads.norm(dim=1)
    grad_penalty = grad_penalty_coeff * torch.pow(grad_per_samps - 1, 2).mean()
    return grad_penalty

and then I calculate the discriminator loss using this gradient penalty term:

...
d_loss_real = discriminator(batch_pairs)  # out: [n_batch x 1]

gen = generator(noisy_batch, z)
disc_in_pair = torch.cat((gen, noisy_batch), dim=1)
d_loss_fake = discriminator(disc_in_pair.detach())

# calculate the gradient penalty
grad_penalty = calc_gradient_penalty(batch_pairs, gen, discriminator)
d_loss = d_loss_real.squeeze() - d_loss_fake.squeeze() + grad_penalty.squeeze()

# back-propagate and update
discriminator.zero_grad()
d_loss.backward(ones)  # ERROR here; `ones` is a tensor of ones matching d_loss's shape
d_optim.step()

The error occurs at the d_loss.backward(ones) call, with the message:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

What changed between these versions to cause this difference, and how can I correctly compute these gradients of gradients? (Is there another preferred way to do this?)

Hi,

Could you try to remove all inplace ops from your code?
It could be related to this issue.
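
For example (a minimal sketch with hypothetical module names, assuming the discriminator uses in-place activations), the swap would look like this:

import torch.nn as nn

# In-place version: can interfere with the double backward needed for
# gradients of gradients (create_graph=True).
block_inplace = nn.Sequential(
    nn.Linear(16, 16),
    nn.ReLU(inplace=True),
)

# Out-of-place version: same computation, but autograd keeps what it
# needs for the second backward pass.
block_out_of_place = nn.Sequential(
    nn.Linear(16, 16),
    nn.ReLU(),
)

# The same applies to tensor-level in-place ops inside the discriminator,
# e.g. prefer `x = x + y` over `x += y` or `x.add_(y)`.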


Thanks! That actually works. I’ve also checked that there is a PR that fixes this (for anyone who visits this post later). I might resort to using a temporary ReLU class until the bugfix lands.
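
For anyone who wants to do the same, this is the kind of temporary stand-in I mean (a minimal sketch; the name MyReLU and the torch.clamp implementation are just my own choice of a non-in-place ReLU, not something taken from the PR):

import torch
import torch.nn as nn

class MyReLU(nn.Module):
    """
    Temporary drop-in replacement for the in-place ReLU in the discriminator,
    built from torch.clamp so the gradient-penalty double backward does not
    go through the in-place path.
    """
    def forward(self, x):
        return torch.clamp(x, min=0)

# usage: replace nn.ReLU(inplace=True) with MyReLU() in the discriminator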

Yes, there is a PR, but it might take a bit of time before this is properly fixed. Using the ReLU class is a good workaround!