Inplace operation error with GAN

I am trying to train a GAN with the algorithm below but am getting the following error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [30, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead.

I know that the problem is due to a bug described here that was fixed in v1.5, but I am not sure how to modify my code to resolve it. Should I wait to call optiG.step() until after the discriminator loss has been computed? Thanks!

optiG = torch.optim.SGD(G.parameters(), lr=g_lr, momentum=momentum, nesterov=True)
optiD = torch.optim.SGD(D.parameters(), lr=d_lr, momentum=momentum, nesterov=True)

for epoch in range(niters):
        
        # Train Generator
        for p in D.parameters():
            p.requires_grad = False # turn off computation for D

        for _ in range(G_iters):
            grid_samp = problem.get_grid_sample()
            pred = G(grid_samp)
            residuals = problem.get_equation(pred, grid_samp, G)
            optiG.zero_grad()
            g_loss = criterion(D(fake), real_labels)
            g_loss.backward(retain_graph=True)
            optiG.step()

        # Train Discriminator
        for p in D.parameters():
            p.requires_grad = True # turn on computation for D

        for _ in range(D_iters):
            if wgan:
                norm_penalty = calc_gradient_penalty(D, real, fake, gp, cuda=False)
            else:
                norm_penalty = torch.zeros(1)

            real_loss = criterion(D(real), real_labels)
            fake_loss = criterion(D(fake), fake_labels)

            optiD.zero_grad()
            d_loss = (real_loss + fake_loss)/2 + norm_penalty
            d_loss.backward(retain_graph=True)
            optiD.step()

This line of code looks strange:

g_loss.backward(retain_graph=True)

as it would keep the computation graph in G alive. Since you are also not detaching fake in:

fake_loss = criterion(D(fake), fake_labels)

these operations:

d_loss = (real_loss + fake_loss)/2 + norm_penalty
d_loss.backward(retain_graph=True)

would most likely try to backprop through G again. Since G's parameters have already been updated by optiG.step(), the graph saved during the generator's forward pass is stale (the parameters are now at a newer version than what the graph recorded), which is exactly what raises the error.
Remove the retain_graph=True usage unless you really need it, and detach() the fake tensor when updating D.
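A minimal sketch of the corrected discriminator step, reusing the variable names from the snippet above (fake is assumed to be the generator output for the current batch):

real_loss = criterion(D(real), real_labels)
fake_loss = criterion(D(fake.detach()), fake_labels)  # detach: no gradient flows back into G

optiD.zero_grad()
d_loss = (real_loss + fake_loss) / 2 + norm_penalty
d_loss.backward()   # no retain_graph=True; the D graph is rebuilt every iteration
optiD.step()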


Awesome, thank you! If I detach the fake tensor when calculating fake_loss, can I remove retain_graph=True from both g_loss.backward() and d_loss.backward()?

I think so. You should definitely try to remove the retain_graph=True usage unless you have a very specific reason to keep it. In a lot of cases this argument is used as a workaround and unfortunately creates other issues, since it doesn't fit the actual use case.
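For the generator side the same idea applies: a fresh graph is built from the forward pass in every iteration, so there is nothing to retain. A sketch, assuming fake is the generator output (the snippet above assigns it to pred):

fake = G(grid_samp)                       # new forward pass -> new graph each iteration
optiG.zero_grad()
g_loss = criterion(D(fake), real_labels)  # D is frozen, but gradients still flow into G
g_loss.backward()                         # frees the graph; no retain_graph=True needed
optiG.step()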
