I have a cascaded SRGAN (two generators and two discriminators).
(The single-generator, single-discriminator version of the code works fine; it is not uploaded here.)
The essentials are as follows:
self.no_of_gan = [16, 64, 256]  # cascade spatial sizes: 16 -> 64 -> 256
with g_optim_wrapper.optim_context(self):
    batch_outputs = self.generator(batch_inputs)
    # list of 2 generator outputs: (B, 3, 64, 64) & (B, 3, 256, 256) -> GAN: 4x super-resolution

all_parsed_losses_g = 0
set_requires_grad(self.discriminator, False)
for i in range(2):
    req_size = self.no_of_gan[i + 1]
    # batch_gt_data is (B, 3, 256, 256); resize it to this stage's resolution
    gt_resize = torch.nn.functional.interpolate(batch_gt_data, size=req_size)
    parsed_losses_g, log_vars_d = self.g_step_with_optim(
        batch_outputs=batch_outputs[i], batch_gt_data=gt_resize,
        optim_wrapper=optim_wrapper, index=i)
    all_parsed_losses_g += parsed_losses_g
    log_vars.update(log_vars_d)
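For context, g_step_with_optim only computes and parses the generator losses for the given stage and returns them; it does not step the optimizer. Simplified, it does something like this (a rough sketch with illustrative loss names and a per-stage discriminator index, not the exact code; the important part is that the adversarial term runs the discriminator on the generator output, so the discriminator's current weights become part of the generator loss graph):

def g_step_with_optim(self, batch_outputs, batch_gt_data, optim_wrapper, index):
    losses = dict()
    # reconstruction term(s)
    losses['loss_pix'] = torch.nn.functional.l1_loss(batch_outputs, batch_gt_data)
    # adversarial term: forward pass through the (frozen) discriminator of this stage;
    # the discriminator's current weights are saved in the autograd graph of this loss
    pred_fake = self.discriminator[index](batch_outputs)
    losses['loss_gan'] = self.gan_loss(pred_fake, target_is_real=True, is_disc=False)
    parsed_losses_g, log_vars = self.parse_losses(losses)
    return parsed_losses_g, log_vars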
The above computes the generator (GAN) loss but does not yet update any parameters. Next I compute the discriminator loss and update the discriminator's weights:
set_requires_grad(self.discriminator, True)
for i in range(self.n_layers):
    req_size = self.no_of_gan[i + 1]
    gt_resize = torch.nn.functional.interpolate(batch_gt_data, size=req_size)
    log_vars_d = self.d_step_with_optim(
        batch_outputs=batch_outputs[i].detach(),
        batch_gt_data=gt_resize,
        optim_wrapper=optim_wrapper, index=i)
    log_vars.update(log_vars_d)

set_requires_grad(self.discriminator, False)
all_parsed_losses_g.backward()
set_requires_grad(self.discriminator, True)
I expected this to work, since the updated discriminator weights should (I thought) have nothing to do with the generator loss backward. Instead I get:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1024, 1]] is at version 2; expected version 1 instead.
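The same class of error can be reproduced with a tiny standalone toy example (made-up single-layer modules, nothing to do with my actual model): build the generator's adversarial loss, step the discriminator's optimizer, then call backward() on the generator loss.

import torch
import torch.nn as nn

def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad_(flag)

G = nn.Linear(8, 8)   # stand-in "generator"
D = nn.Linear(8, 1)   # stand-in "discriminator"
opt_d = torch.optim.SGD(D.parameters(), lr=0.1)

x = torch.randn(4, 8)
fake = G(x)

# generator adversarial loss: D's current weights are saved in this graph,
# because they are needed to backprop through D into G (even with requires_grad off)
set_requires_grad(D, False)
loss_g = -D(fake).mean()

# discriminator update: opt_d.step() modifies D's weights in place
set_requires_grad(D, True)
loss_d = D(fake.detach()).mean()
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# backward through the stale generator graph now raises
# "one of the variables needed for gradient computation has been modified by an inplace operation"
loss_g.backward()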
I have uploaded the code for reproducing the error:
NOTE1: If I move the generator update step before the discriminator update step, the error is resolved. If updating the discriminator weights is the problem, why does the non-cascaded version (one generator and one discriminator) not hit the same error?
NOTE2: Updating the generator weights after the discriminator works as long as the generator loss contains no discriminator-based term (pixel loss, perceptual loss, etc. all work fine; it breaks only once I add the adversarial loss).
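For completeness, this is the ordering from NOTE1 that does not raise the error (a sketch of just the relevant part): the generator backward runs while the discriminator weights are still at the version saved in its graph, and only afterwards is the discriminator updated.

# generator backward first, while the discriminator weights are unchanged
set_requires_grad(self.discriminator, False)
all_parsed_losses_g.backward()

# discriminator updates afterwards
set_requires_grad(self.discriminator, True)
for i in range(self.n_layers):
    req_size = self.no_of_gan[i + 1]
    gt_resize = torch.nn.functional.interpolate(batch_gt_data, size=req_size)
    log_vars_d = self.d_step_with_optim(
        batch_outputs=batch_outputs[i].detach(),
        batch_gt_data=gt_resize,
        optim_wrapper=optim_wrapper, index=i)
    log_vars.update(log_vars_d)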