I am trying to do audio synthesis, incorporating a GAN loss to make the acoustic features (i.e., mel spectrograms) more realistic. As a result, I have a "generator" that synthesizes audio and a "discriminator" (called the "critic" below) that classifies between natural and synthesized audio. I chose Wasserstein GAN with gradient penalty (WGAN-GP) for training.
However, when using the gradient penalty, I get

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

at the loss_critic.backward(retain_graph=True) call below. If I remove the gradient penalty, the code runs without error.
I have looked at this post and this post, but I don’t think I’m facing this same “in-place” issue.
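For context, my understanding is that this error normally appears when backward is run a second time through a graph whose buffers were freed by the first backward, as in this toy snippet (unrelated to my model):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
y.backward()  # The first backward frees the graph's intermediate buffers.
y.backward()  # RuntimeError: Trying to backward through the graph a second time...

In my training loop below, however, each loss only gets one backward call, which is part of why I am confused.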
Here is my training code (Generator is my synthesis model, omitted here for brevity):

import torch
import torch.nn as nn

generator = Generator()
critic = Critic()
mse_loss_func = nn.MSELoss()  # MSE between natural and synthesized features.
generator_optimizer = torch.optim.Adam(generator.parameters())
critic_optimizer = torch.optim.Adam(critic.parameters())
for step, (input_data, target_data) in enumerate(train_loader):
    ### Update the generator.
    for p in critic.parameters():  # Freeze critic parameters.
        p.requires_grad = False
    generator.zero_grad()
    pred_data = generator(input_data)
    mse_loss = mse_loss_func(target_data, pred_data)
    gan_loss = -torch.mean(critic(pred_data))
    generator_loss = mse_loss + 0.05 * gan_loss
    generator_optimizer.zero_grad()
    generator_loss.backward()
    generator_optimizer.step()

    ### Update the critic.
    for p in critic.parameters():  # Unfreeze critic parameters.
        p.requires_grad = True
    critic.zero_grad()
    critic_real = critic(target_data.detach())
    critic_gen = critic(pred_data.detach())
    gp = gradient_penalty(critic, target_data, pred_data)
    loss_critic = -(torch.mean(critic_real) - torch.mean(critic_gen)) + 10 * gp
    critic_optimizer.zero_grad()
    loss_critic.backward(retain_graph=True)  # THIS LINE RAISES THE ERROR.
    critic_optimizer.step()
where the gradient penalty is computed as follows:
def gradient_penalty(critic, real, fake):
    batch_size, C, L = real.shape
    # One random interpolation weight per example, broadcast over channels and time.
    alpha = torch.rand((batch_size, 1, 1)).repeat(1, C, L)
    interpolated_images = real * alpha + fake * (1 - alpha)
    mixed_scores = critic(interpolated_images)
    # Gradient of the critic's scores w.r.t. the interpolated inputs.
    gradient = torch.autograd.grad(inputs=interpolated_images,
                                   outputs=mixed_scores,
                                   grad_outputs=torch.ones_like(mixed_scores),
                                   create_graph=True,
                                   retain_graph=True,
                                   only_inputs=True)[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
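For reference, this function is meant to implement the gradient penalty term from the WGAN-GP paper (Gulrajani et al., 2017):

$$\text{GP} = \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right], \qquad \hat{x} = \alpha\, x_{\text{real}} + (1 - \alpha)\, x_{\text{fake}}, \quad \alpha \sim U[0, 1],$$

which enters loss_critic above with weight $\lambda = 10$.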
And in case it helps, this is my critic model:
class Critic(nn.Module):
    def __init__(self, input_nc=80, ndf=256, n_layers=3, kernel_size=5, stride=2):
        super(Critic, self).__init__()
        sequence = [nn.Conv1d(input_nc, ndf, kernel_size=kernel_size, stride=stride),
                    nn.LeakyReLU(0.2, inplace=False)]
        for n in range(1, n_layers):
            sequence += [nn.Conv1d(ndf, ndf, kernel_size=kernel_size, stride=stride),
                         nn.LeakyReLU(0.2, inplace=False)]
        sequence += [nn.Conv1d(ndf, 1, kernel_size=kernel_size, stride=stride)]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        out = self.model(input)
        out = torch.mean(out, dim=2)  # Mean pooling over the time dimension.
        return out
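For what it's worth, the critic runs fine in isolation. A quick shape check on dummy mel-spectrogram input (the batch size of 4 and length of 128 frames are arbitrary; 80 matches my mel bins):

dummy_mel = torch.randn(4, 80, 128)   # (batch, mel_bins, frames)
scores = Critic()(dummy_mel)
print(scores.shape)                   # torch.Size([4, 1]): one scalar score per example.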
I am very perplexed as to why I am experiencing this error, since I am specifically using retain_graph=True. Additionally, this is my first attempt at using a WGAN, so if you notice any other errors, please let me know! Thank you.

Note: I am using torch==1.5.0.