Hi. I am trying conditional GAN code. The following error comes up when I run the code:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
The solution I found on the discussion forums is to detach fake and real samples.
I cannot understand where exactly should I place .detach() or if it is actually the solution to problem. I tried placing it at some places within the code but the error is still there. Any suggestion will be highly appreciated.
for batch_idx, sample in enumerate(self.data_loader):
self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)
# get data and send them to GPU
blurred = sample['blurred'].to(self.device)
sharp = sample['sharp'].to(self.device)
# get G's output
deblurred = self.generator(blurred)
# denormalize
with torch.no_grad():
denormalized_blurred = denormalize(blurred)
denormalized_sharp = denormalize(sharp)
denormalized_deblurred = denormalize(deblurred)
if batch_idx % 100 == 0:
# save blurred, sharp and deblurred image
self.writer.add_image('blurred', make_grid(denormalized_blurred.cpu()))
self.writer.add_image('sharp', make_grid(denormalized_sharp.cpu()))
self.writer.add_image('deblurred', make_grid(denormalized_deblurred.cpu()))
# get D's output
sharp_discriminator_out = self.discriminator(sharp)
# deblurred_discriminator_out = self.discriminator(deblurred)
deblurred_discriminator_out = self.discriminator(deblurred.detach())
# set critic_updates
if self.config['loss']['adversarial'] == 'wgan_gp_loss':
critic_updates = 5
else:
critic_updates = 1
# train discriminator
discriminator_loss = 0
for i in range(critic_updates):
self.discriminator_optimizer.zero_grad()
# train discriminator on real and fake
if self.config['loss']['adversarial'] == 'wgan_gp_loss':
gp_lambda = self.config['others']['gp_lambda']
alpha = random.random()
interpolates = alpha * sharp + (1 - alpha) * deblurred
interpolates_discriminator_out = self.discriminator(interpolates)
kwargs = {
'gp_lambda': gp_lambda,
'interpolates': interpolates,
'interpolates_discriminator_out': interpolates_discriminator_out,
'sharp_discriminator_out': sharp_discriminator_out,
'deblurred_discriminator_out': deblurred_discriminator_out
}
wgan_loss_d, gp_d = self.adversarial_loss('D', **kwargs)
discriminator_loss_per_update = wgan_loss_d + gp_d
self.writer.add_scalar('wgan_loss_d', wgan_loss_d.item())
self.writer.add_scalar('gp_d', gp_d.item())
elif self.config['loss']['adversarial'] == 'gan_loss':
kwargs = {
'sharp_discriminator_out': sharp_discriminator_out,
'deblurred_discriminator_out': deblurred_discriminator_out
}
gan_loss_d = self.adversarial_loss('D', **kwargs)
discriminator_loss_per_update = gan_loss_d
self.writer.add_scalar('gan_loss_d', gan_loss_d.item())
else:
raise NotImplementedError
discriminator_loss_per_update.backward(retain_graph=True)
discriminator_loss += discriminator_loss_per_update.item()
self.discriminator_optimizer.step()
discriminator_loss /= critic_updates
CedricLy
(Cedric Ly)
February 21, 2021, 1:18pm
3
As I read the Error note you have some operation within this function, which breaks the compuational graph.
maryam_hayat:
self.adversarial_loss
maryam_hayat:
self.adversarial_loss
thank you. I will try to figure out the issue in adversarial_loss