Reusing graph GAN

Hi there!

I’m having some problems with my GAN code; any help would be much appreciated.

I am implementing an SR (super-resolution) model which involves three loss functions: a linear combination of two for the generator, and one for the discriminator.

G1) Loss function between the generator output and target image.

G2) Loss function A between outputs of the discriminator for real (target image) and fake (generator output) inputs.

D) Loss function B between outputs of the discriminator for real (target image) and fake (generator output) inputs.

G2 and D are distinct loss functions.

The problem I am having is the error below; enabling autograd anomaly detection does not give much more information. I am fairly sure the issue stems from optimizer_D.step() modifying the discriminator's parameters in place.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 4; expected version 3 instead.
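For reference, here is a minimal sketch of how this kind of error can arise in general. It is not taken from the training code; the two-layer `disc` model, the SGD optimizer and the two toy losses are all assumptions made purely for illustration. The idea is that a retained graph still points at the weights it was built with, so stepping the optimizer and then calling backward through that older graph trips autograd's version check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a discriminator: two linear layers, so the
# backward pass through the second layer needs that layer's saved weight.
disc = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
opt = torch.optim.SGD(disc.parameters(), lr=0.1)

x = torch.randn(2, 4)
out = disc(x)  # this graph saves references to the current weights

loss_a = out.mean()
opt.zero_grad(set_to_none=True)
loss_a.backward(retain_graph=True)  # keep the graph alive for a second backward
opt.step()  # updates the weights in place, bumping their version counters

loss_b = (out ** 2).mean()  # built on the stale, pre-step graph
try:
    loss_b.backward()
    stale_error = None
except RuntimeError as err:
    stale_error = str(err)

print(stale_error)
```

Recomputing `out = disc(x)` after `opt.step()` builds a fresh graph against the updated weights and avoids the error in this toy case.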

Here is my code:

batch_hr = batch["hr"].to(opt.local_rank)
batch_lr = batch["lr"].to(opt.local_rank)

g_hr = generator(batch_lr)

d_real = discriminator(batch_hr)
d_fake = discriminator(g_hr.detach()) # detached from generator graph

loss_D = D(d_real, d_fake)

optimizer_D.zero_grad(set_to_none=True)
loss_D.backward(retain_graph=True)
optimizer_D.step()

loss_G1 = G1(batch_hr, g_hr)

d_real = discriminator(batch_hr).detach() # detached from discriminator graph
d_fake = discriminator(g_hr)

loss_G2 = G2(d_real, d_fake)

loss_G = loss_G1 + loss_G2

optimizer_G.zero_grad(set_to_none=True)
loss_G.backward()
optimizer_G.step()

I have tried various rearrangements and combinations of clone(), detach() and retain_graph=True but cannot work around this. Further to this, I would ideally only make one call to the generator and two to the discriminator (real/fake once each) but I am not sure this is possible with this combination of losses.

Thank you in advance for any help or advice!

I think I’ve actually figured this out in the end! My focus on reducing the number of discriminator calls turned out to be a bit misplaced, as it was definitely not the slow part of the training iteration. I’ve now got it set up so that the discriminator trains and updates before creating the new outputs which the generator uses for its loss function - basically splitting the code into clear "train D" and "train G" blocks, as I probably should have done from the start :)
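For anyone landing here with the same error, a minimal sketch of that "train D, then train G" layout might look like the following. The tiny linear models, the placeholder bodies of `D`, `G1` and `G2`, and the `train_step` helper are all illustrative assumptions, not the real SR model or losses; only the ordering of the steps is the point.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in models; the real generator/discriminator would be conv nets.
generator = nn.Linear(8, 8)
discriminator = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-3)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

def D(d_real, d_fake):
    # Placeholder discriminator loss: real -> 1, fake -> 0.
    return bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

def G1(target, output):
    # Placeholder content loss between generator output and target.
    return nn.functional.l1_loss(output, target)

def G2(d_real, d_fake):
    # Placeholder adversarial loss: fake should score above real.
    return bce(d_fake - d_real, torch.ones_like(d_fake))

def train_step(batch_lr, batch_hr):
    g_hr = generator(batch_lr)  # single generator call per iteration

    # ---- train D ----
    d_real = discriminator(batch_hr)
    d_fake = discriminator(g_hr.detach())  # no gradient into the generator
    loss_D = D(d_real, d_fake)
    optimizer_D.zero_grad(set_to_none=True)
    loss_D.backward()  # no retain_graph: this graph is not reused
    optimizer_D.step()

    # ---- train G ----
    # Fresh discriminator outputs, computed with the updated weights.
    with torch.no_grad():
        d_real = discriminator(batch_hr)  # constant for the generator loss
    d_fake = discriminator(g_hr)
    loss_G = G1(batch_hr, g_hr) + G2(d_real, d_fake)
    optimizer_G.zero_grad(set_to_none=True)
    loss_G.backward()  # also fills discriminator grads; cleared next iteration
    optimizer_G.step()
    return loss_D.item(), loss_G.item()

batch_lr, batch_hr = torch.randn(4, 8), torch.randn(4, 8)
loss_D_val, loss_G_val = train_step(batch_lr, batch_hr)
```

This costs four discriminator forward passes per iteration instead of two, but every backward call runs through a graph built after the most recent in-place parameter update.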