RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I am trying to implement a Generative Adversarial Network (GAN).
These are the Generator and Discriminator parts:

class UpConvBlock(nn.Module):
    def __init__(self, n_input, n_output, num_classes, k_size=4, stride=2, padding=0,
                 bias=False, dropout_p=0.0, norm=None):
        super(UpConvBlock, self).__init__()
        self.dropout_p = dropout_p
        self.norm = norm
        self.upconv = spectral_norm(nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=bias))
        if norm == "cbn": self.normt = ConditionalBatchNorm2d(n_output, num_classes)
        elif norm == "pixnorm": self.normt = PixelwiseNorm()
        elif norm == "bn": self.normt = nn.BatchNorm2d(n_output)
        self.activ = nn.LeakyReLU(0.05, inplace=True)
        self.dropout = nn.Dropout2d(p=dropout_p)
    
    def forward(self, inputs):
        x0, labels = inputs

        x = self.upconv(x0)
        if self.norm is not None:
            if self.norm == "cbn":
                x = self.activ(self.normt((x, labels)))
            else:
                x = self.activ(self.normt(x))
        if self.dropout_p > 0.0:
            x = self.dropout(x)
        return x


class Generator(nn.Module):
    def __init__(self, nz=128, num_classes=120, channels=3, nfilt=64):
        super(Generator, self).__init__()
        self.nz = nz
        self.num_classes = num_classes
        self.channels = channels
        
        self.label_emb = nn.Embedding(num_classes, nz)
        self.pixelnorm = PixelwiseNorm()
        self.upconv1 = UpConvBlock(2*nz, nfilt*16, num_classes, k_size=4, stride=1, padding=0, norm="cbn", dropout_p=0.15)
        self.upconv2 = UpConvBlock(nfilt*16, nfilt*8, num_classes, k_size=4, stride=2, padding=1, norm="cbn", dropout_p=0.10)
        self.upconv3 = UpConvBlock(nfilt*8, nfilt*4, num_classes, k_size=4, stride=2, padding=1, norm="cbn", dropout_p=0.05)
        self.upconv4 = UpConvBlock(nfilt*4, nfilt*2, num_classes, k_size=4, stride=2, padding=1, norm="cbn", dropout_p=0.05)
        self.upconv5 = UpConvBlock(nfilt*2, nfilt, num_classes, k_size=4, stride=2, padding=1, norm="cbn", dropout_p=0.05)
        self.self_attn = Self_Attn(nfilt)
        self.upconv6 = UpConvBlock(nfilt, 3, num_classes, k_size=3, stride=1, padding=1, norm="cbn")
        self.out_conv = spectral_norm(nn.Conv2d(3, 3, 3, 1, 1, bias=False))  
        self.out_activ = nn.Tanh()
        
    def forward(self, inputs):
        z, labels = inputs
        
        enc = self.label_emb(labels).view((-1, self.nz, 1, 1))
        enc = F.normalize(enc, p=2, dim=1)
        x = torch.cat((z, enc), 1)
        
        x = self.upconv1((x, labels))
        x = self.upconv2((x, labels))
        x = self.upconv3((x, labels))
        x = self.upconv4((x, labels))
        x = self.upconv5((x, labels))
        x = self.self_attn(x)
        x = self.upconv6((x, labels))
        x = self.out_conv(x)
        img = self.out_activ(x)           
        return img
    
    
class Discriminator(nn.Module):
    def __init__(self, num_classes=120, channels=3, nfilt=64):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.num_classes = num_classes

        def down_convlayer(n_input, n_output, k_size=4, stride=2, padding=0, dropout_p=0.0):
            block = [spectral_norm(nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)),
                     nn.BatchNorm2d(n_output),
                     nn.LeakyReLU(0.2, inplace=True),
                    ]
            if dropout_p > 0.0: block.append(nn.Dropout(p=dropout_p))
            return block
        
        self.label_emb = nn.Embedding(num_classes, 64*64)
        self.model = nn.Sequential(
            *down_convlayer(self.channels + 1, nfilt, 4, 2, 1),
            Self_Attn(nfilt),
            
            *down_convlayer(nfilt, nfilt*2, 4, 2, 1, dropout_p=0.10),
            *down_convlayer(nfilt*2, nfilt*4, 4, 2, 1, dropout_p=0.15),
            *down_convlayer(nfilt*4, nfilt*8, 4, 2, 1, dropout_p=0.25),
            
            MinibatchStdDev(),
            spectral_norm(nn.Conv2d(nfilt*8 + 1, 1, 4, 1, 0, bias=False)),
        )

    def forward(self, inputs):
        imgs, labels = inputs

        enc = self.label_emb(labels).view((-1, 1, 64, 64))
        enc = F.normalize(enc, p=2, dim=1)
        x = torch.cat((imgs, enc), 1)   # 4 input feature maps (3 RGB + 1 label)
        
        out = self.model(x)
        return out.view(-1)

    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)        
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

And this is the training part:


for epoch in range(epochs):
    epoch_time = time.perf_counter()
    
    for ii, (real_images, labels) in enumerate(train_loader):
        if real_images.shape[0] != BATCH_SIZE: continue
        
        if use_soft_noisy_labels:
            real_labels = torch.squeeze(torch.empty((BATCH_SIZE, 1), device=device).uniform_(0.70, 0.95))
            fake_labels = torch.squeeze(torch.empty((BATCH_SIZE, 1), device=device).uniform_(0.05, 0.15))
            for p in np.random.choice(BATCH_SIZE, size=np.random.randint((BATCH_SIZE//8)), replace=False):
                real_labels[p], fake_labels[p] = fake_labels[p], real_labels[p] # swap labels
        else:
            real_labels = torch.full((BATCH_SIZE, 1), 1.0, device=device)
            fake_labels = torch.full((BATCH_SIZE, 1), 0.0, device=device)
        
        ############################
        # (1) Update D network
        ###########################
        netD.zero_grad()

        labels = torch.tensor(labels, device=device).long()
        real_images = real_images.to(device)
        noise = torch.randn(BATCH_SIZE, nz, 1, 1, device=device)
        
        outputR = netD((real_images, labels))

        fake_images = netG((noise, labels))

        outputF = netD((fake_images.detach(), labels))
        errD = (torch.mean((outputR - torch.mean(outputF) - real_labels) ** 2) + 
                torch.mean((outputF - torch.mean(outputR) + real_labels) ** 2))/2
        errD.backward(retain_graph=True)
        optimizerD.step()

        ############################
        # (2) Update G network
        ###########################
        netG.zero_grad()
        
        outputF = netD((fake_images, labels))
        errG = (torch.mean((outputR - torch.mean(outputF) + real_labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) - real_labels) ** 2))/2
        errG.backward()
        optimizerG.step()
        
        lr_schedulerG.step(epoch)
        lr_schedulerD.step(epoch)

    if epoch % 10 == 0:
        print('%.2fs [%d/%d] Loss_D: %.4f Loss_G: %.4f' % (
              time.perf_counter()-epoch_time, epoch+1, epochs, errD.item(), errG.item()))
        show_generated_img(6)

During training, I get the error reported in the title:

/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-12-1a0c4f899e6b> in <module>
     41         errG = (torch.mean((outputR - torch.mean(outputF) + real_labels) ** 2) +
     42                 torch.mean((outputF - torch.mean(outputR) - real_labels) ** 2))/2
---> 43         errG.backward()
     44         optimizerG.step()
     45 

1 frames
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 513, 4, 4]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I would like to know why the gradient cannot be computed.
If I understand correctly, it seems that a tensor is being overwritten somewhere, but where does this happen in the code?

I also tried setting torch.autograd.set_detect_anomaly(True) as suggested, but its output is not very informative.
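
For reference, a minimal sketch of how I enabled it (assuming it is set once, before the training loop):

import torch

# Record the forward-pass traceback for every autograd node, so a failing
# backward op reports the forward operation that produced it.
# Note: this slows training down noticeably.
torch.autograd.set_detect_anomaly(True)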


These issues are often raised by using retain_graph=True when it isn't actually needed; it's usually added as a workaround for another issue. Could you explain why retain_graph=True is needed here?
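
For context, here is a minimal, self-contained sketch of the usual failure mode (a toy nn.Linear standing in for netD; not your actual code): backward with retain_graph=True, an in-place parameter update via optimizer.step(), then a second backward through the now-stale graph. The [torch.FloatTensor [1, 513, 4, 4]] in your error message looks like the weight of the final Conv2d in netD (out_channels=1, in_channels=nfilt*8 + 1 = 513, kernel 4x4), which optimizerD.step() updates in place between errD.backward() and errG.backward().

import torch
import torch.nn as nn

lin = nn.Linear(4, 1)                      # stand-in for netD
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(8, 4, requires_grad=True)  # stand-in for the input batch
out = lin(x)                               # lin.weight is saved for backward

loss1 = out.pow(2).mean()
loss1.backward(retain_graph=True)          # graph of `out` is kept alive
opt.step()                                 # updates lin.weight in place

loss2 = out.mean()
loss2.backward()  # RuntimeError: one of the variables needed for gradient
                  # computation has been modified by an inplace operation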

I tried to fit this script on a different dataset.
The above code contains the parameter retain_graph=True. I suppose it is there because outputR (the netD(...) output on the real batch) appears in the computation of both losses.

As suggested, I set errG.backward(retain_graph=False).

Now I get this error:

/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:21: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-12-f108a3fc945f> in <module>
     41         errG = (torch.mean((outputR - torch.mean(outputF) + real_labels) ** 2) +
     42                 torch.mean((outputF - torch.mean(outputR) - real_labels) ** 2))/2
---> 43         errG.backward(retain_graph=False)
     44         optimizerG.step()
     45 

1 frames
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

This error points to a use case where the computation graph doesn't seem to be detached in each iteration; instead you keep attaching to it. The backward() call then tries to calculate the gradients across multiple iterations and fails, since the computation graph from the previous iteration has already been freed.
Check if your use case depends on some recursive logic (e.g. if you are feeding the output of one iteration as the input to the model in the next iteration) and/or if you need to detach some tensors from the previous iteration. In your loop the shared tensor is outputR: it is created before optimizerD.step() and reused in errG, so the generator's backward walks a graph that has already been consumed (or, with retain_graph=True, whose weights were updated in place). One possible restructuring is sketched below.
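
A hedged sketch of that restructuring (names taken from your post; untested against your dataset): drop retain_graph=True entirely and recompute the discriminator outputs after optimizerD.step(), so errG never backpropagates through a graph whose parameters were updated in place.

        netD.zero_grad()
        outputR = netD((real_images, labels))
        fake_images = netG((noise, labels))
        outputF = netD((fake_images.detach(), labels))
        errD = (torch.mean((outputR - torch.mean(outputF) - real_labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) + real_labels) ** 2)) / 2
        errD.backward()           # no retain_graph: this graph is used exactly once
        optimizerD.step()         # in-place update of netD's weights

        netG.zero_grad()
        # Recompute both outputs with the *updated* discriminator instead of
        # reusing the stale outputR. Detaching outputR is safe here: the
        # generator's gradient flows only through outputF.
        outputR = netD((real_images, labels)).detach()
        outputF = netD((fake_images, labels))
        errG = (torch.mean((outputR - torch.mean(outputF) + real_labels) ** 2) +
                torch.mean((outputF - torch.mean(outputR) - real_labels) ** 2)) / 2
        errG.backward()
        optimizerG.step()

This costs one extra forward pass of netD on the real batch per step, but it should resolve both errors: no graph is backpropagated through twice, and no tensor saved for backward is mutated between forward and backward.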
