Gradient computation - CGAN

Hi. I am trying to train a conditional GAN. The following error comes up when I run the code:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The solution I found on the discussion forums is to detach the fake and real samples.
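
The pattern suggested there looks roughly like this, using the variable names from my code below (I am not sure whether this is the right place for it):

    # suggested on the forums: detach the samples before feeding them to D,
    # so that D's backward() does not reach back into G's graph
    sharp_discriminator_out = self.discriminator(sharp.detach())
    deblurred_discriminator_out = self.discriminator(deblurred.detach())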

I cannot understand where exactly I should place .detach(), or whether it is actually the solution to the problem. I tried placing it in a few places in the code, but the error is still there. Any suggestion would be highly appreciated.

    for batch_idx, sample in enumerate(self.data_loader):
        self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)

        # get data and send them to GPU
        blurred = sample['blurred'].to(self.device)
        sharp = sample['sharp'].to(self.device)

        # get G's output
        deblurred = self.generator(blurred)

        # denormalize
        with torch.no_grad():
            denormalized_blurred = denormalize(blurred)
            denormalized_sharp = denormalize(sharp)
            denormalized_deblurred = denormalize(deblurred)

        if batch_idx % 100 == 0:
            # save blurred, sharp and deblurred image
            self.writer.add_image('blurred', make_grid(denormalized_blurred.cpu()))
            self.writer.add_image('sharp', make_grid(denormalized_sharp.cpu()))
            self.writer.add_image('deblurred', make_grid(denormalized_deblurred.cpu()))

        # get D's output
        sharp_discriminator_out = self.discriminator(sharp)
        # deblurred_discriminator_out = self.discriminator(deblurred)
        deblurred_discriminator_out = self.discriminator(deblurred.detach())

        # set critic_updates
        if self.config['loss']['adversarial'] == 'wgan_gp_loss':
            critic_updates = 5
        else:
            critic_updates = 1

        # train discriminator
        discriminator_loss = 0
        for i in range(critic_updates):
            self.discriminator_optimizer.zero_grad()
           
            # train discriminator on real and fake
            if self.config['loss']['adversarial'] == 'wgan_gp_loss':
                gp_lambda = self.config['others']['gp_lambda']
                alpha = random.random()
                interpolates = alpha * sharp + (1 - alpha) * deblurred
                interpolates_discriminator_out = self.discriminator(interpolates)
                kwargs = {
                    'gp_lambda': gp_lambda,
                    'interpolates': interpolates,
                    'interpolates_discriminator_out': interpolates_discriminator_out,
                    'sharp_discriminator_out': sharp_discriminator_out,
                    'deblurred_discriminator_out': deblurred_discriminator_out
                }
                wgan_loss_d, gp_d = self.adversarial_loss('D', **kwargs)
              
                discriminator_loss_per_update = wgan_loss_d + gp_d

                self.writer.add_scalar('wgan_loss_d', wgan_loss_d.item())
                self.writer.add_scalar('gp_d', gp_d.item())
            elif self.config['loss']['adversarial'] == 'gan_loss':
                kwargs = {
                    'sharp_discriminator_out': sharp_discriminator_out,
                    'deblurred_discriminator_out': deblurred_discriminator_out
                }
                gan_loss_d = self.adversarial_loss('D', **kwargs)
                discriminator_loss_per_update = gan_loss_d

                self.writer.add_scalar('gan_loss_d', gan_loss_d.item())
            else:
                raise NotImplementedError

            discriminator_loss_per_update.backward(retain_graph=True)
            discriminator_loss += discriminator_loss_per_update.item()
            self.discriminator_optimizer.step()

        discriminator_loss /= critic_updates

As I read the error note, you have some operation within this function that breaks the computational graph.
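
One thing that often produces exactly this version-counter error in loops like yours is reusing sharp_discriminator_out / deblurred_discriminator_out across the critic updates: after the first discriminator_optimizer.step(), the discriminator's weights have been modified in place, so the graph you keep alive with retain_graph=True is stale when backward() runs again. A rough, untested sketch of how the loop could be restructured, using the names from your post (shown for the plain GAN branch only):

    discriminator_loss = 0
    for i in range(critic_updates):
        self.discriminator_optimizer.zero_grad()

        # recompute D's outputs inside the loop, so every backward()
        # runs on a graph built from the current weights
        sharp_discriminator_out = self.discriminator(sharp)
        deblurred_discriminator_out = self.discriminator(deblurred.detach())

        kwargs = {
            'sharp_discriminator_out': sharp_discriminator_out,
            'deblurred_discriminator_out': deblurred_discriminator_out
        }
        gan_loss_d = self.adversarial_loss('D', **kwargs)

        # no retain_graph needed once the graph is rebuilt every iteration
        gan_loss_d.backward()
        self.discriminator_optimizer.step()

        discriminator_loss += gan_loss_d.item()

The WGAN-GP branch would need the same treatment (recompute interpolates and interpolates_discriminator_out inside the loop), and it is still worth checking adversarial_loss itself for in-place operations (methods ending in an underscore, or += on tensors that require grad).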

Thank you. I will try to figure out the issue in adversarial_loss.