I am trying to train a GAN with the algorithm below, but I am getting the following error:

`RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [30, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead.`

I know that the problem is due to a bug described here that was fixed in v1.5, but I am not sure how to correctly modify my code to solve the issue. Should I wait to call `optiG.step()` until after the discriminator loss has been computed (see the sketch after my training loop below)?
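To check my understanding of the mechanism, here is a minimal toy sketch (not my actual model) that I believe triggers the same class of error on v1.5+:

```python
import torch

lin = torch.nn.Linear(5, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(4, 5, requires_grad=True)
out = lin(x)

loss1 = out.sum()
loss1.backward(retain_graph=True)
opt.step()  # in-place update of lin.weight bumps its autograd version counter

loss2 = (2 * out).sum()
loss2.backward()  # backward through the retained graph sees the modified weight and errors
```

My actual training loop is below: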
```python
optiG = torch.optim.SGD(G.parameters(), lr=g_lr, momentum=momentum, nesterov=True)
optiD = torch.optim.SGD(D.parameters(), lr=d_lr, momentum=momentum, nesterov=True)

for epoch in range(niters):

    # Train Generator
    for p in D.parameters():
        p.requires_grad = False  # turn off computation for D

    for _ in range(G_iters):
        grid_samp = problem.get_grid_sample()
        pred = G(grid_samp)
        residuals = problem.get_equation(pred, grid_samp, G)
        optiG.zero_grad()
        # fake, real, real_labels, fake_labels are built elsewhere in the full script
        g_loss = criterion(D(fake), real_labels)
        g_loss.backward(retain_graph=True)
        optiG.step()  # updates G's parameters in place

    # Train Discriminator
    for p in D.parameters():
        p.requires_grad = True  # turn on computation for D

    for _ in range(D_iters):
        if wgan:
            norm_penalty = calc_gradient_penalty(D, real, fake, gp, cuda=False)
        else:
            norm_penalty = torch.zeros(1)
        real_loss = criterion(D(real), real_labels)
        fake_loss = criterion(D(fake), fake_labels)
        optiD.zero_grad()
        d_loss = (real_loss + fake_loss) / 2 + norm_penalty
        # this backward re-uses fake's retained graph through G, whose
        # parameters optiG.step() has already modified in place
        d_loss.backward(retain_graph=True)
        optiD.step()
```
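Concretely, the reordering I had in mind would look something like this (a sketch of one iteration with `G_iters == D_iters == 1`, using the same `fake`/`real`/label tensors and `norm_penalty` as above):

```python
optiG.zero_grad()
optiD.zero_grad()

# generator loss: D frozen, so gradient flows only into G
for p in D.parameters():
    p.requires_grad = False
g_loss = criterion(D(fake), real_labels)
g_loss.backward(retain_graph=True)  # retain: fake's graph through G is reused below

# discriminator loss: D unfrozen, computed before either optimizer has stepped
for p in D.parameters():
    p.requires_grad = True
real_loss = criterion(D(real), real_labels)
fake_loss = criterion(D(fake), fake_labels)
d_loss = (real_loss + fake_loss) / 2 + norm_penalty
d_loss.backward()

# both backward passes are done, so the in-place parameter updates are now safe
optiG.step()
optiD.step()
```

One thing I am unsure about: with this ordering, `d_loss.backward()` also accumulates into G's gradients (through `fake`) before `optiG.step()` runs. Would it be cleaner to keep my original ordering and use `D(fake.detach())` in the discriminator step instead, so that `d_loss.backward()` never walks back through G's graph at all? Thanks!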