I am building a GAN whose input images have 4 channels and whose output has 1 channel, with a U-Net as the Generator. Here is my code:
MEAN = (0.5, 0.5, 0.5, 0.5,)
STD = (0.5, 0.5, 0.5, 0.5,)
RESIZE = 80 #96 # 128 256
img_size = RESIZE
def argument_parser():
    # parse the command-line arguments and return them as a dict
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--image", type=str, default="opencv_logo.png", help="path to the input image")
    return vars(ap.parse_args())
class Transform():
def __init__(self, resize=RESIZE, mean=MEAN, std=STD):
        if resize > 128:
self.data_transform = transforms.Compose([
#transforms.Resize((resize*2, resize*2)),
transforms.Resize((resize, resize)),
transforms.CenterCrop(resize),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
else:
self.data_transform = transforms.Compose([
#transforms.Resize((resize*2, resize*2)),
#transforms.Resize((resize, resize)),
transforms.CenterCrop(resize),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
def __call__(self, img: Image.Image):
return self.data_transform(img)
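# A quick sanity check of Transform (my assumption: the 4-channel inputs behave like
# RGBA images, so ToTensor yields a 4 x H x W tensor that matches the 4-value MEAN/STD):
if __name__ == "__main__":
    _dummy = Image.fromarray(np.zeros((RESIZE, RESIZE, 4), dtype=np.uint8), mode="RGBA")
    print(Transform()(_dummy).shape)  # expected: torch.Size([4, 80, 80])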
class Dataset(object):
def __init__(self, files: List[str]):
self.files = files
        self.transformer = Transform()
def _separate(self, img) -> Tuple[Image.Image, Image.Image]:
img = np.array(img, dtype=np.uint8)
h, w, _ = img.shape
w = int(w / 2)
return Image.fromarray(img[:, w:, :]), Image.fromarray(img[:, :w, :])
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
img = Image.open(self.files[idx])
input, output = self._separate(img)
        input_tensor = self.transformer(input)
        # the label keeps only the first channel of the transformed output image
        output_t = self.transformer(output)
        output_tensor = output_t[0:1, :, :]
print('input tensor shape :', input_tensor.shape)
print('output tensor shape :', output_tensor.shape)
return input_tensor, output_tensor
def __len__(self):
return len(self.files)
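# Hypothetical glue code (not part of the original post): how the Dataset above would
# typically be wrapped in a DataLoader to produce the `train_dl` consumed by train_fn
# further down. The batch size and the drop_last choice are assumptions.
from torch.utils.data import DataLoader

def make_dataloader(files: List[str], batch_size: int = 2) -> DataLoader:
    # shuffle for training; drop_last avoids a smaller final batch
    return DataLoader(Dataset(files), batch_size=batch_size, shuffle=True, drop_last=True)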
def show_img_sample(img: torch.Tensor, img1: torch.Tensor):
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
ax = axes.ravel()
ax[0].imshow(img.permute(1, 2, 0))
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[0].set_title("input image", c="g")
    ax[1].imshow(img1.squeeze(0), cmap="gray")  # label is single-channel, so drop the channel dim
ax[1].set_xticks([])
ax[1].set_yticks([])
ax[1].set_title("label image", c="g")
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
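# Hypothetical helper (not in the original): preview one input/label pair. The file list
# is whatever you pass in; the * 0.5 + 0.5 simply undoes the (-1, 1) normalisation so
# matplotlib shows sensible values.
def preview_sample(files: List[str], idx: int = 0):
    input_t, label_t = Dataset(files)[idx]
    show_img_sample(input_t * 0.5 + 0.5, label_t * 0.5 + 0.5)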
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
            # kernel_size=3, stride=1, padding=1 gives a "same" convolution:
            # the output spatial size equals the input size.
            # bias=False because each conv is followed by BatchNorm, which has its own shift.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),  # (SpectralNorm)
nn.BatchNorm2d(out_channels),
#nn.ReLU(inplace=True),
nn.LeakyReLU(0.2, inplace=True),
# second convolution
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),  # (SpectralNorm)
nn.BatchNorm2d(out_channels),
#nn.ReLU(inplace=True),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, x):
return self.conv(x) # this calls the DoubleConv we defined above (2 Conv2d steps)
class UpConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpConv, self).__init__()
self.conv = nn.Sequential(
            # kernel_size=3, stride=1, padding=1 gives a "same" convolution:
            # the output spatial size equals the input size.
            # bias=False because each conv is followed by BatchNorm.
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),  # (SpectralNorm)
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
# second convolution
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),  # (SpectralNorm)
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
)
def forward(self, x):
        return self.conv(x)  # applies the two conv + dropout blocks defined above
class Generator(nn.Module):
#class UNET(nn.Module):
in_channels = 4 #5
out_channels = 1 #4 #5
def __init__(self, in_channels=in_channels, out_channels=out_channels, features=[64,128, 256, 512,],):
#super(UNET, self).__init__()
super(Generator, self).__init__()
self.ups = nn.ModuleList()
self.downs = nn.ModuleList()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Down part of the UNET
self.downs1=DoubleConv(in_channels, 64)
self.downs2=DoubleConv(64, 128)
self.downs3=DoubleConv(128, 256)
self.downs4=DoubleConv(256, 512)
#self.bottleneck = DoubleConv(features[-1], features[-1]*2)
self.bottleneck = DoubleConv(512, 1024)
self.ups4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2,padding=0) #(SpectralNorm( ))
self.ups_conv4=UpConv(1024, 512)
self.ups3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2,padding=0) #(SpectralNorm)
self.ups_conv3=UpConv(512, 256)
self.ups2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2,padding=0) #(SpectralNorm)
self.ups_conv2=UpConv(256, 128)
self.ups1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2,padding=0) #(SpectralNorm)
self.ups_conv1=UpConv(128, 64)
'''for feature in reversed(features):
self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2,))
self.ups.append(DoubleConv(feature*2, feature))'''
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
skip_connections = []
x1 = self.downs1(x)
skip_connections.append(x1)
x1 = self.pool(x1)
x2 = self.downs2(x1)
skip_connections.append(x2)
x2 = self.pool(x2)
x3 = self.downs3(x2)
skip_connections.append(x3)
x3 = self.pool(x3)
x4 = self.downs4(x3)
skip_connections.append(x4)
x4 = self.pool(x4)
'''for down in self.downs:
x = down(x)
skip_connections.append(x)
x = self.pool(x)'''
x5 = self.bottleneck(x4)
d6 = self.ups4(x5)
skip_connection4 = skip_connections[3]
print("Size of skip_connection4:", skip_connection4.size())
print("Size of d6 before concat:", d6.size())
d6 = torch.cat((skip_connection4, d6), dim=1)
print("Size of d6 after concat:", d6.size())
d6 = self.ups_conv4(d6)
d7 = self.ups3(d6)
skip_connection3 = skip_connections[2] # x3
print("Size of skip_connection3:", skip_connection3.size())
print("Size of d7 before concat:", d7.size())
d7 = torch.cat((skip_connection3, d7), dim=1)
print("Size of d7 after concat:", d7.size())
d7 = self.ups_conv3(d7)
d8 = self.ups2(d7)
skip_connection2 = skip_connections[1] # x4
print("Size of skip_connection2:", skip_connection2.size())
print("Size of d8 before concat:", d8.size())
d8 = torch.cat((skip_connection2, d8), dim=1)
print("Size of d8 after concat:", d8.size())
d8 = self.ups_conv2(d8)
d9 = self.ups1(d8)
skip_connection1 = skip_connections[0]
print("Size of skip_connection1:", skip_connection1.size())
print("Size of d9 before concat:", d9.size())
d9 = torch.cat((skip_connection1, d9), dim=1)
print("Size of d9 after concat:", d9.size())
d9 = self.ups_conv1(d9)
return torch.tanh(self.final_conv(d9))
# write a test of the UNET
def test():
in_channels = 4 #5
out_channels = 1# 5
    #x = torch.randn((10, 5, 160, 160))
    x = torch.randn((10, 4, 64, 64))  # batch = 10, channels = 4, image = 64 x 64 (divisible by 16)
#model = UNET(in_channels = in_channels, out_channels=out_channels)
model = Generator(in_channels = in_channels, out_channels=out_channels)
preds = model(x)
print(preds.shape)
print(x.shape)
    assert preds.shape == (x.shape[0], out_channels, x.shape[2], x.shape[3])
if __name__ == "__main__":
test()
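# Side note (an assumption about the intended data): the Generator halves H and W four
# times before the bottleneck, so the input height/width should be divisible by 16 for
# the skip-connection concatenations to line up. RESIZE = 80 and the 64 used in test()
# both satisfy this.
if __name__ == "__main__":
    assert RESIZE % 16 == 0, "U-Net with four pooling stages expects H/W divisible by 16"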
class Discriminator(nn.Module):
def __init__(self, in_channels=1,out_channels=4,hidden_channels=32):
super(Discriminator, self).__init__()
        # x (the 1-channel real/fake image) and x1 (the 4-channel condition) are concatenated
        # in forward(), so the first layer sees in_channels + out_channels feature maps
        self.layer1 = self.conv2relu(in_channels + out_channels, hidden_channels//2, 5, cnt=1)
self.layer2 = self.conv2relu(hidden_channels//2, hidden_channels, pool_size= None)
self.layer3 = self.conv2relu(hidden_channels, hidden_channels*2, pool_size=None)
self.layer4 = self.conv2relu(hidden_channels*2, hidden_channels*4, pool_size=None)
self.layer5 = self.conv2relu(hidden_channels*4, hidden_channels*8, pool_size=None)
self.layer6 = self.conv2relu(hidden_channels*8, hidden_channels*16, pool_size=None)
self.layer7 = self.conv2relu(hidden_channels*16, hidden_channels*32, pool_size=None)
        self.layer8 = nn.Conv2d(hidden_channels*32, 1, kernel_size=1)  # (SpectralNorm)
def conv2relu(self, in_c, out_c, kernel_size=3, pool_size=None, cnt=2):
layer = []
for i in range(cnt):
            if i == 0 and pool_size is not None:
# Down width and height
layer.append(nn.AvgPool2d(pool_size)) #(SpectralNorm)
# Down channel size
layer.append(nn.Conv2d(in_c if i == 0 else out_c, # (SpectralNorm
out_c,
kernel_size,
padding=(kernel_size-1)//2))
layer.append(nn.BatchNorm2d(out_c))
layer.append(nn.LeakyReLU(0.2, inplace=True))
return nn.Sequential(*layer)
def forward(self, x, x1):
x = torch.cat((x, x1), dim=1)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out5 = self.layer5(out4)
out6 = self.layer6(out5)
out7 = self.layer7(out6)
return torch.sigmoid(self.layer8(out7))
def test_D():
    in_channels = 1   # channels of the (real or generated) target image fed to D
    out_channels = 4  # channels of the conditioning input image
    #x = torch.randn((10, 5, 160, 160))
    x = torch.randn((1, 1, 30, 30))   # fake/real target: batch = 1, channels = 1
    x1 = torch.randn((1, 4, 30, 30))  # conditioning input: batch = 1, channels = 4
#model = UNET(in_channels = in_channels, out_channels=out_channels)
model = Discriminator(in_channels = in_channels, out_channels=out_channels)
preds = model(x, x1)
print(f'pred tensor dimension: {preds.shape}')
print(f' image dimension: {x.shape} ')
#assert preds.shape == x.shape
if __name__ == "__main__":
test_D()
import re
import random
current_epoch = 0
# num_epoch = 2
# batch=2
# critics = 3
#def train_fn(train_dl, G, D, criterion_bce, criterion_smoothL1, optimizer_g, optimizer_d):
def train_fn(train_dl, G, D, GAN_loss, L1_loss, optimizer_g, optimizer_d):
#for epoch in range (current_epoch, num_epoch):
G.train() #.to(device)
D.train() #.to(device)
    LAMBDA_penalty = 10.0   # reserved for a gradient-penalty term (not used below)
    lambda_pixel = 10.0     # weight of the pixel-wise L1 loss
total_loss_g, total_loss_d = [],[]
#for i, batch in enumerate(tqdm(train_dl)):
#input_img = Variable(batch[0].type(Tensor))
#real_img = Variable(batch[1].type(Tensor))
for i, (input_img, real_img) in enumerate(tqdm(train_dl)):
real_A = input_img.to(device)
real_B = real_img.to(device)
# ------------------
# Train Generators
# ------------------
optimizer_g.zero_grad()
# GAN loss
fake_B = G(real_A)
pred_fake_g = D(fake_B, real_A).squeeze() # ROkwen added .squeeze()
pred_real_label = D(real_B, real_A).squeeze()
        # Adversarial ground truths
        # One-sided label smoothing: draw the real label between 0.8 and 0.9
        real_label = random.uniform(0.8, 0.9)
        real_label = torch.tensor(real_label)
        real_label = real_label.expand_as(pred_real_label).to(device)
#real_label = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
#print(f' real label value: {real_label.mean()}')
fake_label = torch.tensor(0.0)
fake_label = fake_label.expand_as(pred_fake_g).to(device)
#print(f' fake label value: {fake_label.mean()}')
#####
# Adversarial loss
        # Adversarial loss (GAN_loss is the criterion passed in, e.g. nn.MSELoss for LSGAN
        # or nn.BCELoss, since the discriminator already ends in a sigmoid)
        loss_g_adv = GAN_loss(pred_fake_g, real_label)
#print("Adversarial loss:", loss_g_adv)
# Dice loss
#loss_g_dice = calc_loss(fake_B, real_B)
#loss_g_adv = loss_g_dice
#print(f'Dice loss {loss_g_dice}')
# Hinge loss
#loss_g_adv = - pred_fake_g.mean()
#print(f'Ganloss {loss_g_adv}')
# Pixel-wise loss
loss_pixel = L1_loss(fake_B, real_B)
#loss_pixel = L1_loss(torch.tanh(fake_B), real_B)
#torch.tanh(self.final_conv(d9))
#print("pixel loss:", loss_pixel)
#print(f'L1 loss {loss_pixel}')
# Total loss
loss_g = loss_g_adv + lambda_pixel * loss_pixel
#loss_g = loss_g_dice
#print(f'Total GAN loss {loss_g}')
total_loss_g.append(loss_g.item())
loss_g.backward()
optimizer_g.step()
#optimizer_d.zero_grad()
#optimizer_g.zero_grad()
# ------------------
# Train Discriminator
# ------------------
        # n_critic (number of D updates per G update) is assumed to be defined globally
        for _ in range(n_critic):
optimizer_d.zero_grad()
# Calculate Discriminator Losses
# Real loss
pred_real = D(real_B, real_A).squeeze() # ROkwen added .squeeze()
#real_img = input_img.to(device)
#input_img = real_img.to(device)
loss_d_real = GAN_loss(pred_real, real_label)
#print("real D loss:", loss_d_real)
#loss_d_real = torch.nn.ReLU()(1.0 - pred_real).mean()
#print(f'real discriminator loss {loss_d_real}')
# Fake loss
#fake_B = G(real_A)
pred_fake_d = D(fake_B.detach(), real_A).squeeze() # ROkwen added .squeeze()
loss_d_fake = GAN_loss(pred_fake_d, fake_label)
#print("fake D loss:", loss_d_fake)
#print(f'fake discriminator loss {loss_d_fake}')
            loss_d = (loss_d_real + loss_d_fake) * 0.5  # adversarial D loss
#loss_d = loss_d_real + loss_d_fake
total_loss_d.append(loss_d.item())
#print(f'Discriminator loss: {loss_d}')
# Backward + Optimize
loss_d.backward()
optimizer_d.step()
# clear_output()
        # keep the first sample of the last batch for visualisation / saving
        real_A = real_A[0, :, :, :]   # 4 x H x W input (e.g. 4 x 80 x 80)
        real_B = real_B[0, 0, :, :]   # H x W label (first/only channel)
        fake_B = fake_B[0, 0, :, :]   # H x W generated image
        print(f'target images shape 2 {real_A.shape}')
'''print(f'target images shape 2 {real_A.shape}')
real_A = torch.reshape(real_A, (real_A.shape[0],real_A.shape[2],real_A.shape[2])).detach().cpu()
#print(f'input_img images shape {real_A.shape}')
fake_B = torch.reshape(fake_B, (fake_B.shape[0],fake_B.shape[2],fake_B.shape[2])).detach().cpu()
real_B = torch.reshape(real_B, (real_B.shape[0],real_B.shape[2],real_B.shape[2])).detach().cpu()
#print(f'target images shape 2 {real_B.shape}')
show_img_sample3(real_A, fake_B, real_B)
'''
#print(f'target images shape 2 {fake_B.shape[1:]}')
#fake_B = torch.reshape(fake_B, (fake_B.shape[1],fake_B.shape[2],fake_B.shape[2])).detach().cpu()
#real_B = torch.reshape(real_B, (real_B.shape[1],real_B.shape[2],real_B.shape[2])).detach().cpu()
#show_img_sample_training(fake_B,real_B)
#print(f'gradient penalty loss: {gradient_penalty}')
#print(f'Discriminator loss: {loss_d}')
scheduler_g.step()
scheduler_d.step()
lr_D = scheduler_d.get_last_lr()[0]
lr_G = scheduler_g.get_last_lr()[0]
#lr_D = lr
#lr_G = lr
#return mean(total_loss_g), mean(total_loss_d), input_img.detach().cpu(), fake_img.detach().cpu(), real_img.detach().cpu(), lr_D, lr_G #.to(device) #.cpu()
return mean(total_loss_g), mean(total_loss_d), real_A.detach().cpu(), fake_B.detach().cpu(), real_B.detach().cpu(), lr_D, lr_G #.to(device) #.cpu()
#return mean(total_loss_g), mean(total_loss_d), real_A.detach().cpu(), fake_B.detach().cpu(), real_B.detach().cpu() #.to(device) #.cpu()
def saving_img(input_img, fake_img, real_img, e):
os.makedirs("generated", exist_ok=True)
save_image(input_img, f"generated/input{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True)
save_image(fake_img, f"generated/fake{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True)
save_image(real_img, f"generated/real{str(e)}.tiff", value_range=(-1.0, 1.0), normalize=True)
# Changed 'range' to 'value_range' ==> 'range' will soon be deprecated.
def saving_logs(result):
with open("train.pkl", "wb") as f:
pickle.dump([result], f)
def saving_model(D, G, e):
os.makedirs("weight", exist_ok=True)
torch.save(G.state_dict(), f"weight/G{str(e+1)}.pth")
torch.save(D.state_dict(), f"weight/D{str(e+1)}.pth")
def show_losses(g, d, rmse_T):
fig, axes = plt.subplots(1, 3, figsize=(21,6)) # (14, 6)
ax = axes.ravel()
ax[0].plot(np.arange(len(d)).tolist(), d)
ax[0].set_title("Discriminator Loss")
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[1].plot(np.arange(len(g)).tolist(), g)
ax[1].set_title("Generator Loss")
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Loss')
ax[2].plot(np.arange(len(rmse_T)).tolist(), rmse_T)
ax[2].set_title("Root Mean Square Error")
ax[2].set_xlabel('Epochs')
ax[2].set_ylabel('RMSE')
plt.show()
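# Hypothetical driver (not in the original post): one way to wire everything together.
# The learning rates, betas, scheduler, n_critic, num_epoch and loss choices below are
# all assumptions; train_fn reads `device`, `n_critic`, `scheduler_g` and `scheduler_d`
# as globals, so they are created at module level here.
def run_training(train_dl, num_epoch: int = 50):
    global device, n_critic, scheduler_g, scheduler_d
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_critic = 1  # assumed number of discriminator updates per generator update
    G = Generator().to(device)
    D = Discriminator().to(device)
    GAN_loss = nn.MSELoss()   # LSGAN-style adversarial criterion (assumption)
    L1_loss = nn.L1Loss()     # pixel-wise loss
    optimizer_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
    scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_g, step_size=20, gamma=0.5)
    scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d, step_size=20, gamma=0.5)
    result = {"G": [], "D": []}
    for e in range(num_epoch):
        loss_g, loss_d, input_img, fake_img, real_img, lr_D, lr_G = train_fn(
            train_dl, G, D, GAN_loss, L1_loss, optimizer_g, optimizer_d)
        result["G"].append(loss_g)
        result["D"].append(loss_d)
        saving_img(input_img, fake_img, real_img, e)
        if (e + 1) % 10 == 0:
            saving_model(D, G, e)
    saving_logs(result)
    show_losses(result["G"], result["D"], rmse_T=[])  # RMSE tracking is not shown in the post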