Hey all. The GAN code below worked fine for me under PyTorch 1.2, which I'd been using for a while, but I recently updated to PyTorch 1.5 and all of a sudden it no longer runs. I'm now getting the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I’ve looked at several other threads regarding this error, and for the life of me I can’t seem to find any in-place variable modifications that I’m doing in this code.
This is the relevant code:
logits_real, prob_real, features_real = discriminator(real_imgs)
logits_fake, prob_fake, features_fake = discriminator(gen_imgs)

real_mean = features_real.mean(dim=0)
real_std = features_real.std(dim=0)

# ---------------------
#  Train Discriminator
# ---------------------
torch.autograd.set_detect_anomaly(True)
optimizer_d.zero_grad()

# Measure discriminator's ability to classify real from generated samples
real_loss = adversarial_loss(logits_real, valid).mean()
fake_loss = adversarial_loss(discriminator(gen_imgs.detach())[0], fake).mean()
d_loss = (real_loss + fake_loss) / 2.0
d_loss.backward(retain_graph=True)

d_acc_real = prob_real.mean()
d_acc_fake = torch.mean(1.0 - prob_fake)
d_acc = (d_acc_real + d_acc_fake) / 2.0

# if d_acc < 0.75:
#     optimizer_d.step()
optimizer_d.step()

# Train the generator every n_critic iterations
if i % n_critic == 0:
    # -----------------
    #  Train Generator
    # -----------------
    optimizer_g.zero_grad()

    fake_mean = features_fake.mean(dim=0)
    fake_std = features_fake.std(dim=0)

    # Loss measures generator's ability to fool the discriminator
    # g_loss = adversarial_loss(discriminator(gen_imgs)[0], valid).mean()
    g_loss = mse(fake_mean, real_mean)  # + mse(fake_std, real_std)

    g_loss.backward()
    optimizer_g.step()
The error pops up as soon as I call g_loss.backward(), which leads me to believe the issue has to do with the fake_mean and real_mean variables, but as far as I'm aware I'm not doing any in-place modifications to them. I'm at a bit of a loss here; any help would be appreciated!
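In case it helps narrow things down, here is the smallest standalone snippet I could come up with that throws the same kind of error on 1.5. It's a hypothetical toy setup, not my actual model: two stacked Linear layers stand in for the discriminator, and the same forward graph is reused for a second backward() after an optimizer.step() has updated the parameters in place.

```python
import torch

# Hypothetical toy model (not my actual code): two Linear layers whose
# forward graph is reused after an optimizer step.
lin1 = torch.nn.Linear(4, 4)
lin2 = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(list(lin1.parameters()) + list(lin2.parameters()), lr=0.1)

x = torch.randn(2, 4)
out = lin2(lin1(x))  # the graph saves lin2.weight for the backward pass

out.sum().backward(retain_graph=True)  # first backward: fine
opt.step()  # in-place parameter update bumps each param's version counter

caught = None
try:
    (out * 2).sum().backward()  # reuses the graph, needs the old lin2.weight
except RuntimeError as e:
    caught = e

print(caught)  # "... modified by an inplace operation ..." on 1.5+
```

On 1.2 this pattern ran without complaint for me, so maybe something changed in how the optimizers' in-place updates interact with autograd's version counters between those releases?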