I get the following error for a GAN model I am using to perform image colorization. It uses the LAB color space as is common in image colorization. The generator generates the a and b channels for a given L channel. The discriminator is fed all three channels after concatenation.
NOTE: I am using Google Colab, maybe this might be a potential problem? Also, I am using torch version 1.10.0+cu111. I did use a sequential model without skip connections for the generator before this, and I did not have this error then, so I am assuming that is the problem. I can’t quite put my finger on the problem; any help would be appreciated!
This is the full error statement:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 64, 128, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Here is the error stack:
RuntimeError Traceback (most recent call last)
<ipython-input-20-3435b262f1ae> in <module>()
----> 1 trainer.train()
2 frames
<ipython-input-18-1255f97997c7> in train(self)
87 errG = errG + 100 * errG_L1
88 # Calculate gradients for G
---> 89 errG.backward()
90 # Update G
91 optimizer_G.step() # Update the weights
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
305 create_graph=create_graph,
306 inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
308
309 def register_hook(self, hook):
/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
154 Variable._execution_engine.run_backward(
155 tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 156 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
157
158
Here are the imports:
from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import numpy as np
import os
import torch.nn as nn
import torchvision.models as models
import torchvision
import torch.nn.functional as functional
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from PIL import Image
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
from torchvision.transforms.functional import resize
Use any dataset of color images.
I have the following code to get my train, test, and validation images from the folder “Dataset”:
path = "../Dataset/"
paths = np.array(glob.glob(path + "/*.jpg"))
rand_indices = np.random.permutation(len(paths)) # Number of images in dataset
train_indices, val_indices, test_indices = rand_indices[:3600], rand_indices[3600:4000], rand_indices[4000:]
train_paths = paths[train_indices]
val_paths = paths[val_indices]
test_paths = paths[test_indices]
Here is the data loader:
class ColorizeData(Dataset):
def __init__(self, paths):
self.input_transform = T.Compose([T.ToTensor(),
T.Resize(size=(256,256)),
T.Grayscale(),
T.Normalize((0.5), (0.5))
])
self.lab_transform = T.Compose([T.ToTensor(),
T.Resize(size=(256,256)),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.paths = paths
def __len__(self) -> int:
return len(self.paths)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
image = Image.open(self.paths[index]).convert("RGB")
input_image = self.input_transform(image)
image_lab = rgb2lab(image)
image_lab = self.lab_transform(image_lab)
image_l = image_lab[0, :, :]
image_ab = image_lab[1:3, :, :]
return (input_image.float(), image_ab.float(), image_l.float().reshape(1, 256, 256))
Here is the model:
class NetGen(nn.Module):
'''Generator'''
def __init__(self):
super(NetGen, self).__init__()
self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
self.bnorm1 = nn.BatchNorm2d(64)
self.relu1 = nn.LeakyReLU(0.1)
self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
self.bnorm2 = nn.BatchNorm2d(128)
self.relu2 = nn.LeakyReLU(0.1)
self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
self.bnorm3 = nn.BatchNorm2d(256)
self.relu3 = nn.LeakyReLU(0.1)
self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False)
self.bnorm4 = nn.BatchNorm2d(512)
self.relu4 = nn.LeakyReLU(0.1)
self.conv5 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False)
self.bnorm5 = nn.BatchNorm2d(512)
self.relu5 = nn.LeakyReLU(0.1)
self.deconv6 = nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm6 = nn.BatchNorm2d(512)
self.relu6 = nn.ReLU()
self.deconv7 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm7 = nn.BatchNorm2d(256)
self.relu7 = nn.ReLU()
self.deconv8 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm8 = nn.BatchNorm2d(128)
self.relu8 = nn.ReLU()
self.deconv9 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm9 = nn.BatchNorm2d(64)
self.relu9 = nn.ReLU()
self.deconv10 = nn.ConvTranspose2d(64, 2, 3, stride=2, padding=1, output_padding=1, bias=False)
self.tanh = nn.Tanh()
def forward(self, x):
h = x
h = self.conv1(h)
h = self.bnorm1(h)
h = self.relu1(h)
pool1 = h
h = self.conv2(h)
h = self.bnorm2(h)
h = self.relu2(h)
pool2 = h
h = self.conv3(h)
h = self.bnorm3(h)
h = self.relu3(h)
pool3 = h
h = self.conv4(h)
h = self.bnorm4(h)
h = self.relu4(h)
pool4 = h
h = self.conv5(h)
h = self.bnorm5(h)
h = self.relu5(h)
h = self.deconv6(h)
h = self.bnorm6(h)
h = self.relu6(h)
h += pool4
h = self.deconv7(h)
h = self.bnorm7(h)
h = self.relu7(h)
h += pool3
h = self.deconv8(h)
h = self.bnorm8(h)
h = self.relu8(h)
h += pool2
h = self.deconv9(h)
h = self.bnorm9(h)
h = self.relu9(h)
h += pool1
h = self.deconv10(h)
h = self.tanh(h)
return h
class NetDis(nn.Module):
'''Discriminator'''
def __init__(self):
super(NetDis, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.1),
nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.1),
nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.1),
nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
Here is the weight init function:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
Here is the training and validation code:
class Trainer:
def __init__(self, epochs, batch_size, learning_rate, num_workers):
self.epochs = epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_workers = num_workers
self.train_paths = train_paths
self.val_paths = val_paths
self.real_label = 1
self.fake_label = 0
def train(self):
train_dataset = ColorizeData(paths=self.train_paths)
train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,pin_memory=True, drop_last = True)
# Model
model_G = NetGen().to(device)
model_D = NetDis().to(device)
model_G.apply(weights_init)
model_D.apply(weights_init)
optimizer_G = torch.optim.Adam(model_G.parameters(),
lr=self.learning_rate, betas=(0.5, 0.999),
eps=1e-8, weight_decay=0)
optimizer_D = torch.optim.Adam(model_D.parameters(),
lr=self.learning_rate, betas=(0.5, 0.999),
eps=1e-8, weight_decay=0)
criterion = nn.BCELoss()
L1 = nn.L1Loss()
model_G.train()
model_D.train()
# train loop
for epoch in range(self.epochs):
print("Starting Training Epoch " + str(epoch + 1))
for i, data in enumerate(tqdm(train_dataloader)):
inputs, input_ab, input_l = data
inputs = inputs.to(device)
input_ab = input_ab.to(device)
input_l = input_l.to(device)
model_D.zero_grad()
label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
output = model_D(torch.cat([input_l, input_ab], dim=1))
errD_real = criterion(torch.squeeze(output), label)
errD_real.backward()
fake = model_G(input_l)
label.fill_(self.fake_label)
output = model_D(torch.cat([input_l, fake.detach()], dim=1))
errD_fake = criterion(torch.squeeze(output), label)
errD_fake.backward()
errD = errD_real + errD_fake
optimizer_D.step()
model_G.zero_grad()
label.fill_(self.real_label)
output = model_D(torch.cat([input_l, fake], dim=1))
errG = criterion(torch.squeeze(output), label)
errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
errG = errG + 100 * errG_L1
errG.backward()
optimizer_G.step()
print(f'Training: Epoch {epoch + 1} \t\t Discriminator Loss: {\
errD / len(train_dataloader)} \t\t Generator Loss: {\
errG / len(train_dataloader)}')
if (epoch + 1) % 1 == 0:
errD_val, errG_val, val_len = self.validate(model_D, model_G, criterion, L1)
print(f'Validation: Epoch {epoch + 1} \t\t Discriminator Loss: {\
errD_val / val_len} \t\t Generator Loss: {\
errG_val / val_len}')
torch.save(model_G.state_dict(), '../Results/Model_GAN/Generator/saved_model_' + str(epoch + 1) + '.pth')
torch.save(model_D.state_dict(), '../Results/Model_GAN/Discriminator/saved_model_' + str(epoch + 1) + '.pth')
def validate(self, model_D, model_G, criterion, L1):
model_G.eval()
model_D.eval()
with torch.no_grad():
valid_loss = 0.0
val_dataset = ColorizeData(paths=self.val_paths)
val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last = True)
for i, data in enumerate(val_dataloader):
inputs, input_ab, input_l = data
inputs = inputs.to(device)
input_ab = input_ab.to(device)
input_l = input_l.to(device)
label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
output = model_D(torch.cat([input_l, input_ab], dim=1))
errD_real = criterion(torch.squeeze(output), label)
fake = model_G(input_l)
label.fill_(self.fake_label)
output = model_D(torch.cat([input_l, fake.detach()], dim=1))
errD_fake = criterion(torch.squeeze(output), label)
errD = errD_real + errD_fake
label.fill_(self.real_label)
output = model_D(torch.cat([input_l, fake], dim=1))
errG = criterion(torch.squeeze(output), label)
errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
errG = errG + 100 * errG_L1
return errD, errG, len(val_dataloader)
Use this to run the pipeline:
trainer = Trainer(epochs = 100, batch_size = 64, learning_rate = 0.0002, num_workers = 2)
trainer.train()
I coded the training loop while referring to the PyTorch docs, and it worked when the generator did not have skip connections.
Thank you in advance!