Hi,
I have written a GAN based on my understanding of the DCGAN tutorial, and I am getting a runtime error that I do not fully understand. The code I use to train the networks is given below:
Generator_Loss = []
Discriminator_Loss = []
generated_images = []
for epoch in range(args['epochs']):
    for i, data in enumerate(celebA_dataLoader):
        # Training discriminator on real data
        real_data_on_gpu = data[0].to(device)
        #set_trace()
        y_true = torch.full((args["batch_size"], 1), 1, device=device)
        y_pred_real = discriminator_net.forward(real_data_on_gpu)
        #set_trace()
        real_loss = loss_criterion(y_pred_real, y_true)
        discriminator_optimizer.zero_grad()
        real_loss.backward()
        discriminator_optimizer.step()
        # Training discriminator on fake data
        latent_variables = torch.as_tensor(rand.randn(args["batch_size"], args["dim_Z"], 1, 1).astype(np.float32), device=device)
        fake_data_on_gpu = generator_net.forward(latent_variables)
        y_pred_fake = discriminator_net.forward(fake_data_on_gpu)
        fake_loss = loss_criterion(y_pred_fake, y_true.fill_(0))
        discriminator_optimizer.zero_grad()
        fake_loss.backward()
        discriminator_optimizer.step()
        discriminator_loss = real_loss + fake_loss
        # Training generator
        y_pred_fake = discriminator_net.forward(fake_data_on_gpu)
        generator_loss = loss_criterion(y_pred_fake, y_true.fill_(1))
        generator_optimizer.zero_grad()
        generator_loss.backward()
        generator_optimizer.step()
        if i % 50 == 0:
            Generator_Loss.append(generator_loss.item())
            Discriminator_Loss.append(discriminator_loss.item())
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, args["epochs"], i, len(celebA_dataLoader),
                     discriminator_loss.item(), generator_loss.item()))
    if epoch == args["epochs"] - 1:
        with torch.no_grad():
            fake_images = generator_net(latent_variables).detach().cpu()
            generated_images.append(fake_images)
I am getting the following runtime error at the line generator_loss.backward():

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I have searched for this error, but I am still not clear on how the optimizer and autograd interact. My understanding is that a computational graph of the network is created during the forward pass and gets freed after the backward pass. In my case I did only one forward pass through the generator, yet when I try to do the backward pass I get this error.
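To make sure I understand the graph lifecycle correctly, here is a minimal standalone sketch (my own toy example, not my GAN code) that reproduces the same error by calling backward() twice on one graph:

```python
import torch

# The forward pass below builds a small autograd graph for y.
x = torch.ones(2, requires_grad=True)
y = (x * 2).sum()

y.backward()  # first backward: succeeds, and the graph's buffers are freed

try:
    y.backward()  # second backward on the same graph: raises RuntimeError
except RuntimeError as e:
    print("second backward failed:", e)
```

So the error itself makes sense to me in this toy case; what I don't see is where my training loop performs a second backward through the same graph.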
Could someone please help me understand why this is happening?
Warm Regards,
Nirmal