RuntimeError due to inplace operation in GAN generator architecture with skip connections

I get the following error for a GAN model I am using to perform image colorization. It uses the LAB color space, as is common in image colorization: the generator predicts the a and b channels for a given L channel, and the discriminator is fed all three channels after concatenation.
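For reference, here is a minimal sketch of the tensor shapes involved, assuming 256×256 images and an arbitrary batch of 4: the generator's 2-channel ab output is concatenated with the 1-channel L input along the channel dimension, which gives the 3-channel input the discriminator expects.

```python
import torch

batch, height, width = 4, 256, 256  # illustrative sizes, not from the post

input_l = torch.randn(batch, 1, height, width)   # L channel (generator input)
fake_ab = torch.randn(batch, 2, height, width)   # stand-in for the generator's a/b output

# The discriminator sees L and ab stacked along the channel dimension (dim=1)
disc_input = torch.cat([input_l, fake_ab], dim=1)
print(disc_input.shape)  # torch.Size([4, 3, 256, 256])
```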

NOTE: I am using Google Colab; could that be part of the problem? I am also using torch version 1.10.0+cu111. Before this, I used a sequential generator without skip connections and did not get this error, so I assume the skip connections are the problem. I can’t quite put my finger on it, though; any help would be appreciated!

This is the full error message:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 64, 128, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
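The hint at the end of the message refers to autograd's anomaly mode. A minimal sketch of using it, with a toy in-place modification standing in for the real model (the tensors here are hypothetical): with anomaly detection enabled, PyTorch additionally prints the traceback of the forward operation whose backward failed, which points at the offending line.

```python
import torch

x = torch.randn(8, requires_grad=True)

caught = None
# Anomaly mode slows things down; enable it only while debugging.
with torch.autograd.detect_anomaly():
    y = torch.relu(x)   # ReLU saves its output for the backward pass
    y += 1              # in-place change to that saved output
    try:
        y.sum().backward()
    except RuntimeError as err:
        caught = err

print(type(caught).__name__)  # RuntimeError
```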

Here is the error stack:

Here are the imports:

from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import numpy as np
import os
import torch.nn as nn
import torchvision.models as models
import torchvision
import torch.nn.functional as functional
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from PIL import Image
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io

from torchvision.transforms.functional import resize

Use any dataset of color images.
I have the following code to get my train, test, and validation images from the folder “Dataset”:

path = "../Dataset/"
paths = np.array(glob.glob(os.path.join(path, "*.jpg")))
rand_indices = np.random.permutation(len(paths))          # Shuffled indices over all images in the dataset
train_indices, val_indices, test_indices = rand_indices[:3600], rand_indices[3600:4000], rand_indices[4000:]
train_paths = paths[train_indices]
val_paths = paths[val_indices]
test_paths = paths[test_indices]
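The hard-coded indices above assume the folder holds more than 4,000 images. A minimal sketch of a fraction-based split instead, using NumPy only; the helper name `split_paths`, the 80/10/10 fractions and the fixed seed are hypothetical choices, not part of the original code:

```python
import numpy as np

def split_paths(paths, train_frac=0.8, val_frac=0.1, seed=0):
    """Hypothetical helper: shuffle paths and split them into train/val/test by fraction."""
    paths = np.asarray(paths)
    rng = np.random.default_rng(seed)          # seeded for a reproducible split
    indices = rng.permutation(len(paths))
    n_train = int(len(paths) * train_frac)
    n_val = int(len(paths) * val_frac)
    train = paths[indices[:n_train]]
    val = paths[indices[n_train:n_train + n_val]]
    test = paths[indices[n_train + n_val:]]    # the remainder goes to test
    return train, val, test

train_paths, val_paths, test_paths = split_paths([f"img_{i}.jpg" for i in range(50)])
print(len(train_paths), len(val_paths), len(test_paths))  # 40 5 5
```

Seeding the generator keeps the split identical across Colab restarts, so validation numbers stay comparable between runs.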

Here is the data loader:

class ColorizeData(Dataset):
    def __init__(self, paths):
        self.input_transform = T.Compose([T.ToTensor(),
                                          T.Normalize((0.5), (0.5))])
        self.lab_transform = T.Compose([T.ToTensor(),
                                        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.paths = paths

    def __len__(self) -> int:
        return len(self.paths)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        image = Image.open(self.paths[index]).convert("RGB")
        input_image = self.input_transform(image)
        image_lab = rgb2lab(image)
        image_lab = self.lab_transform(image_lab)
        image_l = image_lab[0, :, :]
        image_ab = image_lab[1:3, :, :]
        return (input_image.float(), image_ab.float(), image_l.float().reshape(1, 256, 256))
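A side note on the LAB normalization: `T.ToTensor()` does not rescale float arrays, and `T.Normalize` with mean 0.5 and std 0.5 maps each value x to 2x - 1. Since `rgb2lab` returns L in [0, 100] and a/b roughly within [-110, 110] for sRGB images, the targets end up far outside the [-1, 1] range of the generator's final `tanh`. A minimal sketch of one common alternative scaling (an assumption, not taken from the original post), with an inverse for use before `lab2rgb`:

```python
import numpy as np

# Assumed scaling (one common choice, not from the original post):
# L in [0, 100] -> [-1, 1]; a/b (roughly within [-110, 110] for sRGB) -> about [-1, 1]
def normalize_lab(lab):
    """lab: float array of shape (H, W, 3) as returned by skimage.color.rgb2lab."""
    l = lab[..., :1] / 50.0 - 1.0
    ab = lab[..., 1:] / 110.0
    return l, ab

def denormalize_lab(l, ab):
    """Inverse of normalize_lab, e.g. before calling skimage.color.lab2rgb."""
    return np.concatenate([(l + 1.0) * 50.0, ab * 110.0], axis=-1)

lab = np.stack([np.full((2, 2), 100.0), np.full((2, 2), -55.0), np.full((2, 2), 22.0)], axis=-1)
l, ab = normalize_lab(lab)
print(l.max(), ab.min(), ab.max())               # 1.0 -0.5 0.2
print(np.allclose(denormalize_lab(l, ab), lab))  # True
```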

Here is the model:

class NetGen(nn.Module):
    def __init__(self):
        super(NetGen, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bnorm1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bnorm2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bnorm3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm4 = nn.BatchNorm2d(512)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm5 = nn.BatchNorm2d(512)
        self.relu5 = nn.LeakyReLU(0.1)

        self.deconv6 = nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU()

        self.deconv7 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm7 = nn.BatchNorm2d(256)
        self.relu7 = nn.ReLU()

        self.deconv8 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm8 = nn.BatchNorm2d(128)
        self.relu8 = nn.ReLU()

        self.deconv9 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm9 = nn.BatchNorm2d(64)
        self.relu9 = nn.ReLU()

        self.deconv10 = nn.ConvTranspose2d(64, 2, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, x):
        h = x
        h = self.conv1(h)
        h = self.bnorm1(h)
        h = self.relu1(h) 
        pool1 = h

        h = self.conv2(h)
        h = self.bnorm2(h)
        h = self.relu2(h) 
        pool2 = h

        h = self.conv3(h) 
        h = self.bnorm3(h)
        h = self.relu3(h)
        pool3 = h

        h = self.conv4(h) 
        h = self.bnorm4(h)
        h = self.relu4(h)
        pool4 = h

        h = self.conv5(h) 
        h = self.bnorm5(h)
        h = self.relu5(h)

        h = self.deconv6(h)
        h = self.bnorm6(h)
        h = self.relu6(h) 
        h += pool4

        h = self.deconv7(h)
        h = self.bnorm7(h)
        h = self.relu7(h) 
        h += pool3

        h = self.deconv8(h)
        h = self.bnorm8(h)
        h = self.relu8(h)
        h += pool2

        h = self.deconv9(h)
        h = self.bnorm9(h)
        h = self.relu9(h)
        h += pool1

        h = self.deconv10(h)
        h = self.tanh(h) 
        return h

class NetDis(nn.Module):
    def __init__(self):
        super(NetDis, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),

            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),

            nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),

            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),

            nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False),

            nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),

            nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
            nn.Sigmoid(),  # nn.BCELoss expects probabilities in [0, 1]
        )

    def forward(self, x):
        return self.main(x)
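A quick shape check for the discriminator, assuming 256×256 inputs as in the dataset class: five stride-2 convolutions take 256 down to 8, and the 8×8 convolution reduces that to 1×1, so the model emits one value per image. The LeakyReLU activations and the final Sigmoid in this sketch are assumptions (`nn.BCELoss` needs probabilities in [0, 1]); the convolution stack matches `NetDis`.

```python
import torch
import torch.nn as nn

def make_discriminator():
    # Same convolution stack as NetDis; the activations are assumed for this sketch
    layers = []
    channels = [3, 64, 128, 256, 512, 512]
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers += [nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False), nn.LeakyReLU(0.2)]
    layers += [
        nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),  # 8x8 -> 1x1
        nn.LeakyReLU(0.2),
        nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
        nn.Sigmoid(),  # probabilities for nn.BCELoss
    ]
    return nn.Sequential(*layers)

disc = make_discriminator()
with torch.no_grad():
    out = disc(torch.randn(2, 3, 256, 256))
print(out.shape)  # torch.Size([2, 1, 1, 1])
```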

Here is the weight init function:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
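`nn.Module.apply` calls a function on every submodule recursively, so an initializer like this is typically used as `model_G.apply(weights_init)` right after the model is built. A small self-contained sketch with a hypothetical toy model; note that `'Conv'` also matches `ConvTranspose2d`, so the generator's deconvolutions get the same initialization.

```python
import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: N(0, 0.02) for conv weights, N(1, 0.02) and 0 for BatchNorm
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:          # also matches ConvTranspose2d
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Hypothetical toy model standing in for NetGen / NetDis
model = nn.Sequential(nn.Conv2d(1, 64, 3), nn.BatchNorm2d(64), nn.ConvTranspose2d(64, 64, 3))
model.apply(weights_init)   # visits every submodule
```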

Here is the training and validation code:

class Trainer:
    def __init__(self, epochs, batch_size, learning_rate, num_workers):
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_workers = num_workers
        self.train_paths = train_paths
        self.val_paths = val_paths        
        self.real_label = 1
        self.fake_label = 0

    def train(self):
        train_dataset = ColorizeData(paths=self.train_paths)
        train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last=True)
        # Model
        model_G = NetGen().to(device)
        model_D = NetDis().to(device)
        model_G.apply(weights_init)
        model_D.apply(weights_init)

        optimizer_G = torch.optim.Adam(model_G.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        criterion = nn.BCELoss()
        L1 = nn.L1Loss()

        # train loop
        for epoch in range(self.epochs):
            print("Starting Training Epoch " + str(epoch + 1))
            errD_total, errG_total = 0.0, 0.0
            for i, data in enumerate(tqdm(train_dataloader)):
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)

                # Update D: real (L, ab) pairs -> real label
                optimizer_D.zero_grad()
                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)
                errD_real.backward()

                # Update D: generated (L, fake ab) pairs -> fake label
                fake = model_G(input_l)
                label.fill_(self.fake_label)
                output = model_D(torch.cat([input_l, fake.detach()], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizer_D.step()

                # Update G: adversarial loss (fool D) + L1 loss to the ground-truth ab channels
                optimizer_G.zero_grad()
                label.fill_(self.real_label)
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0), -1), input_ab.view(input_ab.size(0), -1))
                errG = errG + 100 * errG_L1
                errG.backward()  # with the inplace skip connections, the RuntimeError is raised here
                optimizer_G.step()

                errD_total += errD.item()
                errG_total += errG.item()

            print(f'Training: Epoch {epoch + 1} \t\t Discriminator Loss: {errD_total / len(train_dataloader)}'
                  f' \t\t Generator Loss: {errG_total / len(train_dataloader)}')
            if (epoch + 1) % 1 == 0:
                errD_val, errG_val, val_len = self.validate(model_D, model_G, criterion, L1)
                print(f'Validation: Epoch {epoch + 1} \t\t Discriminator Loss: {errD_val / val_len}'
                      f' \t\t Generator Loss: {errG_val / val_len}')
                torch.save(model_G.state_dict(), '../Results/Model_GAN/Generator/saved_model_' + str(epoch + 1) + '.pth')
                torch.save(model_D.state_dict(), '../Results/Model_GAN/Discriminator/saved_model_' + str(epoch + 1) + '.pth')

    def validate(self, model_D, model_G, criterion, L1):
        errD_total, errG_total = 0.0, 0.0
        with torch.no_grad():
            val_dataset = ColorizeData(paths=self.val_paths)
            val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last=True)
            for i, data in enumerate(val_dataloader):
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)

                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)

                fake = model_G(input_l)
                label.fill_(self.fake_label)
                output = model_D(torch.cat([input_l, fake], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                errD = errD_real + errD_fake

                label.fill_(self.real_label)
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0), -1), input_ab.view(input_ab.size(0), -1))
                errG = errG + 100 * errG_L1

                errD_total += errD.item()
                errG_total += errG.item()

        return errD_total, errG_total, len(val_dataloader)

Use this to run the pipeline:

trainer = Trainer(epochs = 100, batch_size = 64, learning_rate = 0.0002, num_workers = 2)
trainer.train()

I coded the training loop while referring to the PyTorch docs, and it worked when the generator did not have skip connections.

Thank you in advance!

I first assumed you might be running into this issue.
However, I don’t see that pattern in your code.

Could you replace the inplace skip connections h += poolX with their out-of-place versions h = h + poolX and check if this solves the issue?
I guess h is needed for the gradient calculation of some layers (e.g. nn.ReLU uses its output in the backward pass), which breaks if you modify it inplace.
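A minimal reproduction of this failure mode with toy tensors (the names are hypothetical stand-ins for `h` and `pool1`). It matches the reported error: the [64, 64, 128, 128] tensor is the output of `relu9` (batch 64, 64 channels, 128×128), which `h += pool1` modifies in place. The encoder's `LeakyReLU` layers use their input rather than their output for the backward pass, which is why the `poolX` tensors themselves are not the problem.

```python
import torch
import torch.nn as nn

relu = nn.ReLU()
skip = torch.randn(2, 4, requires_grad=True)   # stand-in for pool1
x = torch.randn(2, 4, requires_grad=True)

# Inplace skip connection: modifies the output ReLU saved for backward
h = relu(x)
h += skip
try:
    h.sum().backward()
    inplace_ok = True
except RuntimeError:
    inplace_ok = False

# Out-of-place: h + skip allocates a new tensor, so ReLU's output stays intact
h = relu(x)
h = h + skip
h.sum().backward()

print(inplace_ok)          # False
print(x.grad is not None)  # True
```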

That fixed it! Thank you so much. I still don’t think I fully grasp why this happens, though. Could you elaborate?

Also, in the thread you mentioned, there is a comment about using squeeze() and view(). I will look into that as well!
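Regarding squeeze() versus view(), one difference worth knowing (not necessarily the point of the linked comment): `torch.squeeze` removes every dimension of size 1, including the batch dimension when a batch holds a single sample, whereas `view(-1)` always keeps one entry per sample. With `batch_size=64` and `drop_last=True` this does not trigger in the code above, but a batch of size 1 would make `nn.BCELoss` reject the shapes:

```python
import torch

criterion = torch.nn.BCELoss()
output = torch.rand(1, 1, 1, 1)       # discriminator output for a batch of one image
label = torch.full((1,), 1.0)

print(torch.squeeze(output).shape)    # torch.Size([])  (batch dim removed too)
print(output.view(-1).shape)          # torch.Size([1])

try:
    criterion(torch.squeeze(output), label)   # input [] vs target [1]
    squeeze_ok = True
except ValueError:
    squeeze_ok = False

loss = criterion(output.view(-1), label)      # shapes [1] and [1] match
```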

Since the out-of-place operation fixed the issue, I don’t think the linked post is related, but it might still give you more insight into potential future issues.

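The reason inplace operations cause this issue is described, with an example, in this post. Depending on the operation, its outputs are needed to calculate the gradients in the backward pass. Autograd tracks this with a version counter on each tensor: every inplace operation increments it, and the backward pass checks that each saved tensor is still at the version it had when it was saved. If you manipulate such an output inplace, the original values are gone, so instead of silently computing wrong gradients, autograd raises the error ("is at version 1; expected version 0").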
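The version counter can be inspected directly. A minimal sketch; note that `_version` is an internal tensor attribute, used here only for illustration and not as a stable API:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.relu(x)      # autograd saves y (at version 0) for ReluBackward0
print(y._version)      # 0
y.add_(1.0)            # any inplace op bumps the version counter
print(y._version)      # 1, hence "is at version 1; expected version 0"
```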

Also, I think using inplace operations would prevent the JIT from fusing these operations, so you might want to avoid them entirely.

Ok, got it. Thanks again, you’ve been a huge help!