WGAN Gradient penalty error even with retain_graph=True

I am working on audio synthesis and incorporating a GAN loss to produce more realistic acoustic features (i.e. mel spectrograms). I therefore have a "generator" that synthesizes audio and a "discriminator" that classifies between natural and synthesized audio. I chose Wasserstein GAN with gradient penalty (WGAN-GP) for training the GAN.
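For context, the critic loss I am trying to implement is the standard WGAN-GP objective (the lambda = 10 factor corresponds to the 10 * gp term in my code below):

L_D = \mathbb{E}[D(\tilde{x})] - \mathbb{E}[D(x)] + \lambda \, \mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], \qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \;\; \epsilon \sim U[0, 1]

where x is a natural mel spectrogram and \tilde{x} a synthesized one.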

However, when using the gradient penalty, I get RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. at the loss_critic.backward(retain_graph=True) call below. If I remove the gradient penalty, the code runs without error.
I have looked at this post and this post, but I don't think I'm facing the same "in-place" issue.

Here is the code for my training:

generator = Generator()
critic = Critic()
generator_optimizer = torch.optim.Adam(generator.parameters())
critic_optimizer = torch.optim.Adam(critic.parameters())

for step, (input_data, target_data) in enumerate(train_loader):    
    ### Update the Generator.
    for p in critic.parameters():  # Freeze critic parameters.
        p.requires_grad = False

    generator.zero_grad()
    pred_data = generator(input_data)

    mse_loss = mse_loss_func(target_data, pred_data)
    gan_loss = -torch.mean(critic(pred_data))
    generator_loss = mse_loss + 0.05 * gan_loss

    generator_optimizer.zero_grad()
    generator_loss.backward()
    generator_optimizer.step()

    ### Update the Critic.
    for p in critic.parameters():  # Unfreeze critic parameters.
        p.requires_grad = True

    critic.zero_grad()
    critic_real = critic(target_data.detach()) 
    critic_gen = critic(pred_data.detach())
    gp = gradient_penalty(critic, target_data, pred_data)
    loss_critic = -(torch.mean(critic_real) - torch.mean(critic_gen)) + 10 * gp

    critic_optimizer.zero_grad()
    loss_critic.backward(retain_graph=True)  # THIS LINE HAS ERROR.
    critic_optimizer.step()

where the gradient penalty is computed as follows:

def gradient_penalty(critic, real, fake):
    batch_size, C, L = real.shape
    # One interpolation factor per example, broadcast over the feature dims.
    alpha = torch.rand((batch_size, 1, 1)).repeat(1, C, L)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Critic scores at the interpolated points.
    mixed_scores = critic(interpolated_images)
    # Gradient of the scores w.r.t. the interpolates; create_graph=True so the
    # penalty term itself stays differentiable w.r.t. the critic parameters.
    gradient = torch.autograd.grad(inputs=interpolated_images,
                                   outputs=mixed_scores,
                                   grad_outputs=torch.ones_like(mixed_scores),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]
    # Flatten per example and penalize deviations of the L2 norm from 1.
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

And in case it helps, this is my critic model:

class Critic(nn.Module):
    def __init__(self, input_nc=80, ndf=256, n_layers=3, kernel_size=5, stride=2):
        super(Critic, self).__init__()

        sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kernel_size, stride=stride), nn.LeakyReLU(0.2, inplace=False)]
        for n in range(1, n_layers):
            sequence += [nn.Conv1d(ndf, ndf, kernel_size=kernel_size, stride=stride), nn.LeakyReLU(0.2, inplace=False)]
            
        sequence += [nn.Conv1d(ndf, 1, kernel_size=kernel_size, stride=stride)] 
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        out = self.model(input)
        out = torch.mean(out, dim=2)  # Mean pooling layer.
        return out
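For reference, the critic outputs one scalar score per example, e.g. with a made-up batch of 8 mel spectrograms (80 bins x 200 frames):

x = torch.randn(8, 80, 200)   # (batch, mel bins, frames) -- example sizes only
print(Critic()(x).shape)      # torch.Size([8, 1])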

I am very perplexed as to why I am getting this error, since I am explicitly passing retain_graph=True. Additionally, this is my first attempt at using WGAN, so if you notice any other errors please let me know! Thank you.

Note: I am using torch==1.5.0.

Spent more time debugging. I haven't solved the problem yet, but if I remove the generator's gradient step, then the critic's gradient step runs without any error.

Maybe try using:

# Gradient-penalty hook
def gradient_penalty(
    critic,
    real: torch.Tensor,
    fake: torch.Tensor,
    gamma: float = 1.
):
    # Randomize alpha, so every call sees a different point
    # on the line connecting the real and fake samples.
    # ('device' is assumed to be defined elsewhere, e.g. torch.device("cuda")).
    alpha = torch.rand(1, 1, device=device)
    alpha = alpha.expand(real.size())

    # Compute the interpolated points.
    x_interpolates = alpha * real + ((1 - alpha) * fake)

    # Detach the interpolates from any upstream graph (e.g. the generator)
    # and re-enable gradients, so the penalty is computed w.r.t. the
    # interpolated points only.
    x_interpolates = x_interpolates.to(device).detach()
    x_interpolates.requires_grad_(True)

    # Compute the critic's prediction for x_interpolates.
    # This is necessary for computing their gradients with torch.autograd.
    critic_interpolates = critic(x_interpolates)

    # Now we have the interpolated points and the critic predictions for them,
    # so we can use torch.autograd to compute the gradients for us.
    gradients = torch.autograd.grad(outputs=critic_interpolates, inputs=x_interpolates,
                                    grad_outputs=torch.ones_like(critic_interpolates),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Compute the gradient penalty for K = 1, which forces unit gradients.
    # Flatten per example before taking the norm, then multiply by the
    # regularization parameter 'gamma'.
    K = 1
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = gamma * ((gradients.norm(2, dim=1) - K) ** 2).mean()

    return gradient_penalty
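As a rough sketch using the variable names from your training loop, it could be dropped in as:

gp = gradient_penalty(critic, target_data, pred_data.detach(), gamma=10.)
loss_critic = -(torch.mean(critic_real) - torch.mean(critic_gen)) + gp

(pred_data is detached so the penalty doesn't try to backpropagate into the generator's graph, see the reply below, and gamma takes the place of the hard-coded 10 * gp.)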

Maybe too late for this post, but it may help anyway.

Didn't run the code, but I would say the problem is that pred_data is not detached in this line:
gp = gradient_penalty(critic, target_data, pred_data)

When the penalty gradients are back-propagated during the critic training step, they flow all the way back into the generator, but the generator's graph was already freed by generator_loss.backward(). That is why PyTorch tells you to retain the graph in that first backward call, but that's not what you want. Just call the gradient penalty as:

gp = gradient_penalty(critic, target_data, pred_data.detach())

Also, the second backward pass doesn’t need to retain the graph:
loss_critic.backward(retain_graph=True)  ->  loss_critic.backward()

Finally, target_data doesn't need to be detached; it is a leaf tensor, so gradients can't flow any further back:

critic_real = critic(target_data.detach())  ->  critic_real = critic(target_data)
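Putting those three changes together, a minimal sketch of the critic update (based on the code in the question) would be:

    critic.zero_grad()
    critic_optimizer.zero_grad()

    critic_real = critic(target_data)                  # leaf tensor, no detach needed
    critic_gen = critic(pred_data.detach())            # detach so the critic loss doesn't reach G
    gp = gradient_penalty(critic, target_data, pred_data.detach())  # detach here as well

    loss_critic = -(torch.mean(critic_real) - torch.mean(critic_gen)) + 10 * gp
    loss_critic.backward()                             # retain_graph no longer needed
    critic_optimizer.step()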