Training with two optimizers - fails in torch 1.5.0, works in 1.1.0

Hi,

I ran my old code for training a simple GAN in torch 1.5.0 and got the following error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-58-a9e3fa0eb03a> in <module>()
     22 
     23         optim_G.zero_grad()
---> 24         G_loss.backward()
     25         optim_G.step()
     26 

1 frames
/usr/local/lib/python3.6/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024, 1]], which is output 0 of TBackward, is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

The same script works just fine in torch 1.1.0. What could cause the problem, and how can I overcome it? I know one way would be to optimize just one module per iteration, but I'd prefer to keep the current training procedure. Also, setting inplace=False on both ReLU and LeakyReLU does not help.
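A minimal sketch of the mechanism the error message describes, independent of the GAN code: autograd saves the tensors it needs for backward together with their version counter, and any in-place change to such a tensor bumps the counter. That includes a weight update done under torch.no_grad(), which is how PyTorch's built-in optimizers update parameters in recent versions. So the offending in-place operation need not be an activation, which would explain why inplace=False has no effect. The tensors x and w are illustrative, and w._version is an internal attribute read here only to make the counter visible.

```python
import torch

torch.manual_seed(0)

# A tiny "layer": y = x @ w. Autograd saves w to compute the gradient w.r.t. x.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 1, requires_grad=True)
loss = (x @ w).sum()

version_before = w._version  # version of w when it was saved for backward

# Simulate an optimizer step: update the weight in place, outside the graph.
with torch.no_grad():
    w.sub_(0.1)

version_after = w._version  # the in-place update bumped the counter

try:
    loss.backward()  # needs the saved (pre-update) w, so autograd refuses
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(version_before, version_after)
print(error_message)
```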

Here is the simplified code:

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, data_dim, hid_dim, z_dim):
        super(Generator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(z_dim, hid_dim),
            nn.LeakyReLU(0.2),

            nn.Linear(hid_dim, data_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.layers(x)


class Discriminator(nn.Module):
    def __init__(self, data_dim, hid_dim):
        super(Discriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(data_dim, hid_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hid_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


def losses(DG, DR, eps=1e-6):
    D_loss = torch.log(DR + eps) + torch.log(1 - DG + eps)
    G_loss = torch.log(DG + eps)

    return -torch.mean(D_loss), -torch.mean(G_loss)


# n_epochs, data_loader, device, z_dim, G, D, optim_G and optim_D are
# defined elsewhere (omitted from this simplified snippet).
for i in range(n_epochs):
    for j, (batch, _) in enumerate(data_loader):
        n_samples = batch.shape[0]
        images = batch.reshape(n_samples, -1).to(device)

        z = 2 * torch.rand(n_samples, z_dim).to(device) - 1

        Gz = G(z)
        DG = D(Gz)
        DR = D(images)

        D_loss, G_loss = losses(DG, DR)

        optim_D.zero_grad()
        D_loss.backward(retain_graph=True)
        optim_D.step()  # updates D's weights in place

        optim_G.zero_grad()
        G_loss.backward()  # backprops through D, which needs D's pre-step weights
        optim_G.step()
```
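Written out, the two losses computed by losses() are shown below. The notation is introduced here only for illustration: N is n_samples, ε is eps, x_i are the rows of images, z_i the noise vectors, and θ_D, θ_G the parameters of D and G. The chain rule for the generator loss makes the dependency explicit: G's gradient passes through ∇_x D, which is evaluated with D's weights as they were during the forward pass.

$$
\mathcal{L}_D = -\frac{1}{N}\sum_{i=1}^{N}\Big[\log\big(D_{\theta_D}(x_i)+\epsilon\big) + \log\big(1 - D_{\theta_D}(G_{\theta_G}(z_i)) + \epsilon\big)\Big],
\qquad
\mathcal{L}_G = -\frac{1}{N}\sum_{i=1}^{N}\log\big(D_{\theta_D}(G_{\theta_G}(z_i)) + \epsilon\big)
$$

$$
\frac{\partial \mathcal{L}_G}{\partial \theta_G} = -\frac{1}{N}\sum_{i=1}^{N}\frac{1}{D_{\theta_D}(G_{\theta_G}(z_i)) + \epsilon}\;\nabla_x D_{\theta_D}(x)\Big|_{x = G_{\theta_G}(z_i)}\;\frac{\partial G_{\theta_G}(z_i)}{\partial \theta_G}
$$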

Hey,

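You can find an explanation of why this happens here: https://github.com/pytorch/pytorch/issues/39141#issuecomment-636881953
The short story is that the code in 1.1.0 was returning wrong gradients: optim_D.step() modifies D's weights in place, but G_loss.backward() still needs the values those weights had during the forward pass in order to backpropagate through D into G. Now PyTorch properly raises an error because it cannot compute those gradients.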
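One possible workaround, assuming you want to keep updating both networks every iteration from the same forward pass: compute all gradients with torch.autograd.grad before either optimizer touches the weights, then assign them to .grad and step both optimizers. The sizes, the nn.Sequential stand-ins for Generator/Discriminator, the Adam optimizers and learning rate, and the train_step helper are all hypothetical; this is a sketch, not the fix from the linked issue.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and small models standing in for Generator/Discriminator.
data_dim, hid_dim, z_dim, n_samples = 8, 16, 4, 5
G = nn.Sequential(
    nn.Linear(z_dim, hid_dim), nn.LeakyReLU(0.2),
    nn.Linear(hid_dim, data_dim), nn.Tanh(),
)
D = nn.Sequential(
    nn.Linear(data_dim, hid_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3),
    nn.Linear(hid_dim, 1), nn.Sigmoid(),
)
optim_D = torch.optim.Adam(D.parameters(), lr=1e-3)  # assumed optimizer and lr
optim_G = torch.optim.Adam(G.parameters(), lr=1e-3)


def losses(DG, DR, eps=1e-6):
    # Same losses as in the question.
    D_loss = torch.log(DR + eps) + torch.log(1 - DG + eps)
    G_loss = torch.log(DG + eps)
    return -torch.mean(D_loss), -torch.mean(G_loss)


def train_step(images):
    """One iteration that updates both D and G from a single forward pass."""
    z = 2 * torch.rand(images.shape[0], z_dim) - 1
    DG = D(G(z))
    DR = D(images)
    D_loss, G_loss = losses(DG, DR)

    d_params = list(D.parameters())
    g_params = list(G.parameters())
    # Compute every gradient while D's weights are still unmodified.
    # retain_graph=True keeps the shared graph alive for the second call.
    d_grads = torch.autograd.grad(D_loss, d_params, retain_graph=True)
    g_grads = torch.autograd.grad(G_loss, g_params)

    # Only now is it safe to update the weights in place.
    for p, g in zip(d_params, d_grads):
        p.grad = g
    for p, g in zip(g_params, g_grads):
        p.grad = g
    optim_D.step()
    optim_G.step()
    return D_loss.item(), G_loss.item()


images = 2 * torch.rand(n_samples, data_dim) - 1  # random stand-in for a real batch
d_loss, g_loss = train_step(images)
print(f"D_loss={d_loss:.4f}  G_loss={g_loss:.4f}")
```

If you'd rather keep loss.backward() and optimizer.step(), the more common GAN pattern is to compute the discriminator loss on Gz.detach(), step optim_D, and then call D(Gz) again to build G_loss. That avoids the error too, but it is a slightly different procedure: G's gradient is then taken through the already-updated D.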


Works like a charm. Thanks a bunch.
