I am trying to do audio synthesis, incorporating a GAN loss to produce more realistic acoustic features (i.e. mel spectrograms). As a result, I have a “generator” that synthesizes audio and a “discriminator” (the critic) that classifies between natural and synthesized audio. I train the GAN as a Wasserstein GAN with gradient penalty (WGAN-GP).
However, when I use the gradient penalty, I get the following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
It is raised by the loss_critic.backward(retain_graph=True) call below. If I remove the gradient penalty, the code runs without error.
I have looked at this post and this post, but I don’t think I’m facing the same “in-place” issue.
Here is the code for my training:
    generator = Generator()
    critic = Critic()
    generator_optimizer = torch.optim.Adam(generator.parameters())
    critic_optimizer = torch.optim.Adam(critic.parameters())

    for step, (input_data, target_data) in enumerate(train_loader):
        ### Update the Generator.
        for p in critic.parameters():  # Freeze critic parameters.
            p.requires_grad = False
        generator.zero_grad()
        pred_data = generator(input_data)
        mse_loss = mse_loss_func(target_data, pred_data)
        gan_loss = -torch.mean(critic(pred_data))
        generator_loss = mse_loss + 0.05 * gan_loss
        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()

        ### Update the Critic.
        for p in critic.parameters():  # Unfreeze critic parameters.
            p.requires_grad = True
        critic.zero_grad()
        critic_real = critic(target_data.detach())
        critic_gen = critic(pred_data.detach())
        gp = gradient_penalty(critic, target_data, pred_data)
        loss_critic = -(torch.mean(critic_real) - torch.mean(critic_gen)) + 10 * gp
        critic_optimizer.zero_grad()
        loss_critic.backward(retain_graph=True)  # THIS LINE HAS ERROR.
        critic_optimizer.step()
where the gradient penalty is computed as follows:

    def gradient_penalty(critic, real, fake):
        batch_size, C, L = real.shape
        # One interpolation weight per sample, broadcast over channels and time.
        alpha = torch.rand((batch_size, 1, 1)).repeat(1, C, L)
        interpolated_images = real * alpha + fake * (1 - alpha)
        mixed_scores = critic(interpolated_images)
        # torch.autograd.grad returns a tuple; take the gradient w.r.t. the single input.
        gradient = torch.autograd.grad(inputs=interpolated_images,
                                       outputs=mixed_scores,
                                       grad_outputs=torch.ones_like(mixed_scores),
                                       create_graph=True,
                                       retain_graph=True,
                                       only_inputs=True)[0]
        # Flatten to (batch_size, C * L) so the norm is taken per sample.
        gradient = gradient.view(gradient.shape[0], -1)
        gradient_norm = gradient.norm(2, dim=1)
        gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
        return gradient_penalty
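For reference, the critic objective this code is meant to implement is, as far as I understand it, the standard WGAN-GP loss. The symbols are my own notation, not names from the code: $D$ is the critic, $x$ a real sample (`target_data`), $\tilde{x}$ a generated sample (`pred_data`), $\hat{x}$ the interpolated sample, and $\lambda = 10$ the penalty weight used in `loss_critic`.

```latex
\begin{aligned}
\hat{x} &= \alpha\, x + (1 - \alpha)\, \tilde{x}, \qquad \alpha \sim \mathcal{U}[0, 1], \\
L_{\text{critic}} &= \mathbb{E}\big[D(\tilde{x})\big] - \mathbb{E}\big[D(x)\big]
  + \lambda\, \mathbb{E}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big].
\end{aligned}
```

The first two terms correspond to `-(torch.mean(critic_real) - torch.mean(critic_gen))`, and the last term to `10 * gp`.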
And if it helps, this is the model of my critic:
    class Critic(nn.Module):
        def __init__(self, input_nc=80, ndf=256, n_layers=3, kernel_size=5, stride=2):
            super(Critic, self).__init__()
            sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kernel_size, stride=stride),
                        nn.LeakyReLU(0.2, inplace=False)]
            for n in range(1, n_layers):
                sequence += [nn.Conv1d(ndf, ndf, kernel_size=kernel_size, stride=stride),
                             nn.LeakyReLU(0.2, inplace=False)]
            sequence += [nn.Conv1d(ndf, 1, kernel_size=kernel_size, stride=stride)]
            self.model = nn.Sequential(*sequence)

        def forward(self, input):
            out = self.model(input)
            out = torch.mean(out, dim=2)  # Mean pooling layer.
            return out
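As a side check on the critic's temporal downsampling, here is a small sketch in plain Python (no PyTorch needed). It applies the usual `Conv1d` output-length formula for no padding and dilation 1, which is what the default arguments above imply. The helper names `conv1d_out_len` and `critic_out_len` are hypothetical and only for illustration.

```python
def conv1d_out_len(length, kernel_size=5, stride=2):
    """Output length of an unpadded, undilated 1-D convolution."""
    return (length - kernel_size) // stride + 1


def critic_out_len(length, n_convs=4, kernel_size=5, stride=2):
    """Frames left after the critic's conv stack (n_layers=3 gives 4 convs).

    Returns 0 if the input is too short for some layer in the stack.
    """
    for _ in range(n_convs):
        length = conv1d_out_len(length, kernel_size, stride)
        if length < 1:
            return 0
    return length


print(critic_out_len(200))  # 9 frames, which the mean pooling then averages
print(critic_out_len(61))   # 1, the shortest input that survives all 4 convs
print(critic_out_len(60))   # 0, too short for the last conv
```

So with these settings the critic needs at least 61 input frames, and the mean over `dim=2` reduces whatever remains to a single score per sample.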
I am perplexed as to why I get this error, since I am explicitly passing retain_graph=True. Additionally, this is my first attempt at using a WGAN, so if you notice any other errors, please let me know! Thank you.
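One experiment that may help narrow this down is a minimal sketch with toy models. It assumes small `nn.Sequential` stand-ins for the generator and critic (hypothetical, not the models above) and a hypothetical `detach_inputs` flag. The idea it tests: the generator update has already called `backward()` through the graph that produced `pred_data`, and the gradient penalty receives `pred_data` without `.detach()`, so the critic's backward pass may try to walk that freed graph again.

```python
import torch
import torch.nn as nn


def toy_gradient_penalty(critic, real, fake, detach_inputs):
    """Gradient penalty on 2-D toy data; optionally cut the link to the generator."""
    if detach_inputs:
        real, fake = real.detach(), fake.detach()
    alpha = torch.rand(real.size(0), 1)
    interpolated = real * alpha + fake * (1 - alpha)
    if not interpolated.requires_grad:
        # After detaching, the interpolation is a leaf and must track gradients itself.
        interpolated.requires_grad_(True)
    scores = critic(interpolated)
    grad = torch.autograd.grad(outputs=scores, inputs=interpolated,
                               grad_outputs=torch.ones_like(scores),
                               create_graph=True)[0]
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()


def critic_step_raises(detach_inputs):
    """Run one generator update, then one critic update; report a RuntimeError."""
    torch.manual_seed(0)
    # Toy stand-ins (assumptions): nonlinear so the penalty depends on its input.
    generator = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
    critic = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
    x, real = torch.randn(3, 4), torch.randn(3, 4)

    # Generator update: this backward frees the buffers of the generator graph.
    fake = generator(x)
    (-critic(fake).mean()).backward()

    # Critic update, reusing the same `fake` tensor for the penalty.
    gp = toy_gradient_penalty(critic, real, fake, detach_inputs)
    loss = critic(fake.detach()).mean() - critic(real).mean() + 10 * gp
    try:
        loss.backward()
    except RuntimeError:
        return True
    return False


print(critic_step_raises(detach_inputs=False))  # True: backward hits the freed generator graph
print(critic_step_raises(detach_inputs=True))   # False: the penalty graph stops at the critic input
```

If the same pattern holds for the real models, passing `pred_data.detach()` into `gradient_penalty` and calling `requires_grad_(True)` on the interpolated tensor would be worth trying. In this sketch `retain_graph=True` on the critic's backward call makes no difference, because the buffers were freed earlier by the generator's backward call.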
Note: I am using