Hi
I am building a VAE-GAN (https://miro.medium.com/max/2992/0*KEmfTtghsCDu6UTb.png). It has a separate loss function for each of the three parts of the network (encoder, decoder, discriminator), and each loss depends on the outputs/weights of the other networks.
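For context, enc, dec and disc are three separate networks, each with its own optimiser. A minimal stand-in for my setup looks roughly like this (my real architectures are larger; the class definitions, latent size and learning rate below are placeholders, only the wiring matters):

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28*28, 2*20)          # 20-dim latent; outputs mu and std
    def forward(self, x):
        mu, log_std = self.fc(x).chunk(2, dim=1)
        return mu, log_std.exp()

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 28*28)
    def forward(self, z):
        return torch.sigmoid(self.fc(z))

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28*28, 256)
        self.l2 = nn.Linear(256, 64)
        self.out = nn.Linear(64, 1)
    def forward(self, x):
        f1 = torch.relu(self.l1(x))
        f2 = torch.relu(self.l2(f1))
        return torch.sigmoid(self.out(f2)), f1, f2  # probability plus two intermediate feature layers

enc, dec, disc = Encoder().cuda(), Decoder().cuda(), Discriminator().cuda()
enc_opt = torch.optim.Adam(enc.parameters(), lr=1e-3)
dec_opt = torch.optim.Adam(dec.parameters(), lr=1e-3)
disc_opt = torch.optim.Adam(disc.parameters(), lr=1e-3)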
When the second backward pass is called, PyTorch throws this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
I think it is because the network weights/outputs that the later loss functions depend on have already been updated in place by the earlier optimiser steps before the later backward passes run. I have tried freezing the networks' gradients with a small set_requires_grad helper (sketched below) and then backpropagating, but to no avail.
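The helper just toggles requires_grad on every parameter of the given network, roughly:

def set_requires_grad(net, flag):
    # enable/disable gradient tracking for all of net's parameters
    for p in net.parameters():
        p.requires_grad = flag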
What is best practice for dealing with this kind of issue?
My code looks like this:
for data in trainLoader:
    i += 1
    X, _ = data
    X = X.view(-1, 28*28)
    X = X.cuda()

    # zero gradients from the last step
    enc_opt.zero_grad()
    dec_opt.zero_grad()
    disc_opt.zero_grad()

    # forward passes
    mu, std = enc(X)                              # encode data; mu = E[z]
    X_tilde = dec(mu)                             # reconstructed image from training data
    _, disc_1_real, disc_2_real = disc(X)         # discriminator features for real images
    _, disc_1_fake, disc_2_fake = disc(X_tilde)   # discriminator features for reconstructions
    Z = z_sample(mu.clone().detach(), std.clone().detach())  # sample z using the reparameterisation trick
    X_p = dec(Z.detach())                         # randomly generated image

    # define losses
    L_llike = mse_loss(disc_1_fake, disc_1_real) + mse_loss(disc_2_fake, disc_2_real)  # discriminator feature-wise error
    # L_llike = mse_loss(X_tilde, X)              # (alternative: plain pixel-wise reconstruction)
    L_GAN = GAN_loss(X, X_tilde, X_p, disc)
    L_prior = KL_loss(mu, std)                    # regularisation loss

    # define per-network losses
    L_enc = L_prior + L_llike
    L_dec = gamma*L_llike.clone() - L_GAN
    L_disc = L_GAN.clone()

    # train encoder
    set_requires_grad(enc, True)
    set_requires_grad(dec, False)
    set_requires_grad(disc, False)
    L_enc.backward(retain_graph=True)
    enc_opt.step()

    # train decoder
    set_requires_grad(enc, False)
    set_requires_grad(dec, True)
    set_requires_grad(disc, False)
    L_dec.backward(retain_graph=True)
    dec_opt.step()

    # train discriminator
    set_requires_grad(enc, False)
    set_requires_grad(dec, False)
    set_requires_grad(disc, True)
    L_disc.backward()
    disc_opt.step()
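For completeness, the remaining helpers are roughly as follows (the exact forms may differ slightly from what I run, but this is the idea): z_sample is the reparameterisation trick, KL_loss is the usual diagonal-Gaussian KL term, GAN_loss is a standard BCE discriminator loss over real images, reconstructions and random samples, mse_loss is torch.nn.functional.mse_loss, and gamma is just a scalar weight (placeholder value below).

import torch
import torch.nn.functional as F

gamma = 1e-2        # weight between reconstruction and adversarial terms (placeholder)
mse_loss = F.mse_loss

def z_sample(mu, std):
    # reparameterisation trick: z = mu + std * eps, eps ~ N(0, I)
    return mu + std * torch.randn_like(std)

def KL_loss(mu, std):
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian, averaged over the batch
    return -0.5 * torch.mean(torch.sum(1 + 2*torch.log(std) - mu**2 - std**2, dim=1))

def GAN_loss(X, X_tilde, X_p, disc):
    # discriminator should output 1 for real images and 0 for reconstructions/samples
    real, _, _ = disc(X)
    fake_rec, _, _ = disc(X_tilde)
    fake_smp, _, _ = disc(X_p)
    return (F.binary_cross_entropy(real, torch.ones_like(real))
            + F.binary_cross_entropy(fake_rec, torch.zeros_like(fake_rec))
            + F.binary_cross_entropy(fake_smp, torch.zeros_like(fake_smp)))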