I’m using the torch.autograd.grad function to calculate “gradients w.r.t. gradients” in order to train a WGAN-GP model (paper). However, after upgrading to v1.0.1.post2, the buffers in the graph still seem to get freed even though I pass retain_graph=True. The same code runs fine on v1.0.0.
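For reference, the general pattern I rely on is the usual double-backward one. Here is a minimal sketch (toy tensors and names, not my actual model) of what I mean by “gradients w.r.t. gradients”:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = (x ** 2).sum()

# first-order gradients, kept differentiable via create_graph=True
(g,) = torch.autograd.grad(y, x, create_graph=True, retain_graph=True)

# a scalar built from those gradients; backpropagating through it is the
# "gradients w.r.t. gradients" step that the gradient penalty needs
penalty = (g.norm(dim=1) - 1).pow(2).mean()
penalty.backward()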
Here’s my code for calculating the GP (gradient penalty).
import random

import torch
from torch import autograd

# `device` and `grad_penalty_coeff` are defined elsewhere in my script

def calc_gradient_penalty(real_samps, gen_samps, discriminator):
    """
    Calculate the gradient penalty (GP).
    """
    # linear interpolation of real / fake images
    alpha = random.random()  # interpolation constant
    samps_interp = (real_samps - gen_samps) * alpha + gen_samps
    # set `requires_grad=True` to store the grad value
    samps_interp = samps_interp.detach().requires_grad_()
    # pass through discriminator
    score = discriminator(samps_interp)
    # calculate the gradients of the scores w.r.t. the interpolated images
    grads = autograd.grad(
        outputs=score,
        inputs=samps_interp,
        grad_outputs=torch.ones(score.size()).to(device),
        create_graph=True,
        retain_graph=True)[0]
    grad_per_samps = grads.norm(dim=1)
    grad_penalty = grad_penalty_coeff * torch.pow(grad_per_samps - 1, 2).mean()
    return grad_penalty
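(For context, this is meant to implement the penalty term lambda * E[(||grad_xhat D(xhat)||_2 - 1)^2] from the WGAN-GP paper, with grad_penalty_coeff playing the role of lambda and samps_interp playing the role of the interpolated sample xhat.)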
and then I calculate the discriminator loss using this gradient penalty term:
...
d_loss_real = discriminator(batch_pairs) # out: [n_batch x 1]
gen = generator(noisy_batch, z)
disc_in_pair = torch.cat((gen, noisy_batch), dim=1)
d_loss_fake = discriminator(disc_in_pair.detach())
# calculate the gradient penalty
grad_penalty = calc_gradient_penalty(batch_pairs, gen, discriminator)
d_loss = d_loss_real.squeeze() - d_loss_fake.squeeze() + grad_penalty.squeeze()
# back-propagate and update
discriminator.zero_grad()
d_loss.backward(ones) # ERROR
d_optim.step()
The error occurs at the d_loss.backward(ones) call, with the message:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
What changed between these versions that leads to this difference, and how can I make this way of calculating the gradients of gradients work (or are there other preferred ways to do this)?