Please see my full code:
#------------------------------------------------------------------------------#
# DEPENDENCIES
#------------------------------------------------------------------------------#
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
from PIL import Image
from torchvision.models.segmentation import deeplabv3_resnet50
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
import random
from torch.utils.data import random_split
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
import torch.nn.init as init
from collections import deque
import numpy as np
#------------------------------------------------------------------------------#
# Custom Dataset Class
#------------------------------------------------------------------------------#
class CustomSegmentationDataset(Dataset):
    def __init__(self, data_dir, labels_dir, image_transform=None, label_transform=None):
        self.data_dir = data_dir
        self.labels_dir = labels_dir
        self.image_transform = image_transform
        self.label_transform = label_transform
        # List all image files in the data directory, sorted so that images and labels
        # pair up by position (os.listdir gives no guaranteed order)
        self.image_paths = sorted(os.path.join(data_dir, fname) for fname in os.listdir(data_dir) if fname.endswith('.jpg') or fname.endswith('.png'))
        self.label_paths = sorted(os.path.join(labels_dir, fname) for fname in os.listdir(labels_dir) if fname.endswith('.jpg') or fname.endswith('.png'))
        # Ensure that the number of data and labels match
        assert len(self.image_paths) == len(self.label_paths), "Mismatch between number of images and labels"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')
        label = Image.open(self.label_paths[idx]).convert('L')  # Assuming labels are images
        # Apply the transformations if any
        if self.image_transform:
            image = self.image_transform(image)
        if self.label_transform:
            label = self.label_transform(label)
        return image, label
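
# Optional usage sketch (not part of the original pipeline): builds the dataset from two
# placeholder directories and prints the first (image, label) pair's shapes.
# "./images" and "./masks" are hypothetical paths, not the real dataset locations,
# so this only runs if such directories exist.
def _peek_dataset(data_dir="./images", labels_dir="./masks"):
    ds = CustomSegmentationDataset(
        data_dir, labels_dir,
        image_transform=transforms.ToTensor(),
        label_transform=transforms.ToTensor(),
    )
    image, label = ds[0]
    print(len(ds), image.shape, label.shape)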
#------------------------------------------------------------------------------#
# Dual Attention (Channel Attention + Spatial Attention)
#------------------------------------------------------------------------------#
class DualAttention(nn.Module):
    def __init__(self, in_channels):
        super(DualAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x


class ChannelAttention(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // 16, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.avg_pool(x)
        max_out = self.max_pool(x)
        out = avg_out + max_out
        out = out.view(out.size(0), -1)  # Flatten to (batch_size, channels)
        out = self.fc(out).view(out.size(0), out.size(1), 1, 1)
        return x * out  # Apply attention to the original feature map


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(x_cat)
        return x * self.sigmoid(out)  # Apply attention to the original feature map
#------------------------------------------------------------------------------#
# Custom Generator (Encoder-Decoder Based)
#------------------------------------------------------------------------------#
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, padding_mode='reflect')
        self.bn1 = nn.BatchNorm2d(out_c, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        return x


class conv1_cross_1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        return self.conv1x1(x)


class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.attention = DualAttention(out_c)  # Attention module (note: not applied in forward below)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p


class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(out_c, affine=True),
            nn.ReLU(inplace=True)
        )
        self.attention = DualAttention(out_c)  # Add attention mechanism

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = x + skip           # Summation-based skip connection
        x = self.attention(x)  # Apply attention after the skip connection
        return x


class classifier_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.outputs = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
        )

    def forward(self, inputs):
        x = self.outputs(inputs)
        return x
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = encoder_block(1, 96)
        self.e2 = encoder_block(96, 96)
        self.e3 = encoder_block(96, 96)
        self.e4 = encoder_block(96, 96)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True)
        )
        self.d1 = decoder_block(96, 96)
        self.d2 = decoder_block(96, 96)
        self.d3 = decoder_block(96, 96)
        self.d4 = decoder_block(96, 96)
        self.outputs = classifier_block(96, 3)  # Output channels = 3 for 3 classes

    def forward(self, inputs):
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
        # Bottleneck
        b = self.bottleneck(p4)
        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        # Output
        x = self.outputs(d4)  # Output shape: (batch_size, 3, height, width)
        return x
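
# Optional sanity check (not part of the original code): a minimal sketch that runs the
# generator on a dummy 1-channel batch and prints the output shape. With four 2x2
# max-pools in the encoder, the input height/width are assumed to be divisible by 16.
def _check_generator_shapes(size=256):
    g = Generator()
    dummy = torch.randn(1, 1, size, size)  # (batch, channels, H, W)
    with torch.no_grad():
        out = g(dummy)
    print(out.shape)  # expected: torch.Size([1, 3, size, size])
    return out.shape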
#------------------------------------------------------------------------------#
# Custom Discriminator (pixel-wise, encoder-decoder based)
#------------------------------------------------------------------------------#
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        x = self.conv(x)
        x = x + skip  # Addition-based skip connection
        # x = torch.cat((x, skip), dim=1)  # Concatenation-based skip connection (alternative)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Encoder layers
        self.e1 = EncoderBlock(3, 96)
        self.e2 = EncoderBlock(96, 96)
        self.e3 = EncoderBlock(96, 96)
        self.e4 = EncoderBlock(96, 96)
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Decoder layers (for skip connections)
        self.d1 = DecoderBlock(96, 96)
        self.d2 = DecoderBlock(96, 96)
        self.d3 = DecoderBlock(96, 96)
        self.d4 = DecoderBlock(96, 96)
        # Final output layer (pixel-wise): one logit per pixel
        self.outputs = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        # Encoder forward pass
        s1 = self.e1(x)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)
        # Bottleneck (intermediate layer)
        b = self.bottleneck(s4)
        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        # Final pixel-wise output (raw logits)
        x = self.outputs(d4)
        return x
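
# Optional sanity check (not part of the original code): feeds a dummy 3-channel batch
# through the pixel-wise discriminator. Five stride-2 stages mean the input height/width
# are assumed to be divisible by 32; the logit map comes out at half the input resolution.
def _check_discriminator_shapes(size=256):
    d = Discriminator()
    dummy = torch.randn(1, 3, size, size)
    with torch.no_grad():
        out = d(dummy)
    print(out.shape)  # expected: torch.Size([1, 1, size // 2, size // 2])
    return out.shape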
#------------------------------------------------------------------------------#
# Custom Segmentation U-Net
#------------------------------------------------------------------------------#
class SegmentationModel(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(SegmentationModel, self).__init__()
        self.encoder1 = self.conv_block(in_channels, 96)
        self.encoder2 = self.conv_block(96, 96)
        self.encoder3 = self.conv_block(96, 96)
        self.encoder4 = self.conv_block(96, 96)
        self.bottleneck = self.conv_block(96, 96)
        self.decoder4 = self.conv_block(96, 96)
        self.decoder3 = self.conv_block(96, 96)
        self.decoder2 = self.conv_block(96, 96)
        self.decoder1 = self.conv_block(96, 96)
        self.final_conv = nn.Conv2d(96, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        bottleneck = self.bottleneck(enc4)
        dec4 = self.decoder4(bottleneck)
        dec3 = self.decoder3(dec4 + enc4)
        dec2 = self.decoder2(dec3 + enc3)
        dec1 = self.decoder1(dec2 + enc2)
        x = self.final_conv(dec1 + enc1)
        return x
#------------------------------------------------------------------------------#
# Custom CycleGAN
#------------------------------------------------------------------------------#
class CycleGAN(nn.Module):
    def __init__(self, num_classes):
        super(CycleGAN, self).__init__()
        self.G_A2B = Generator()
        self.G_B2A = Generator()
        self.D_A = Discriminator()
        self.D_B = Discriminator()
        # The generators output 3-channel maps, so the segmentation network takes 3 input
        # channels and predicts num_classes classes
        self.Seg = SegmentationModel(in_channels=3, out_channels=num_classes)
        self.criterion_adv = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()
        self.criterion_seg = nn.CrossEntropyLoss()

    def forward(self, real_A, real_B, seg_A, seg_B):
        fake_B = self.G_A2B(real_A)
        fake_A = self.G_B2A(real_B)
        D_A_fake = self.D_A(fake_A.detach())
        D_B_fake = self.D_B(fake_B.detach())
        cycle_A = self.G_B2A(fake_B)
        cycle_B = self.G_A2B(fake_A)
        identity_A = self.G_B2A(real_A)
        identity_B = self.G_A2B(real_B)
        D_A_real = self.D_A(real_A)
        D_B_real = self.D_B(real_B)
        seg_fake_B = self.Seg(fake_B)
        seg_fake_A = self.Seg(fake_A)
        return {
            'fake_A': fake_A,
            'fake_B': fake_B,
            'cycle_A': cycle_A,
            'cycle_B': cycle_B,
            'identity_A': identity_A,
            'identity_B': identity_B,
            'D_A_real': D_A_real,
            'D_A_fake': D_A_fake,
            'D_B_real': D_B_real,
            'D_B_fake': D_B_fake,
            'seg_fake_A': seg_fake_A,
            'seg_fake_B': seg_fake_B
        }
#------------------------------------------------------------------------------#
# IoU Calculation
#------------------------------------------------------------------------------#
def calculate_iou(pred, target, num_classes):
    pred = pred.view(-1)      # Flatten the prediction tensor
    target = target.view(-1)  # Flatten the target tensor
    ious = []
    for c in range(num_classes):
        pred_c = (pred == c)                              # Boolean mask for predicted class
        target_c = (target == c)                          # Boolean mask for target class
        intersection = (pred_c & target_c).sum().float()  # Intersection
        union = (pred_c | target_c).sum().float()         # Union
        # Handle the case where union is zero (no pixels for this class)
        if union == 0:
            ious.append(0.0)  # IoU is 0 if there are no pixels for this class
        else:
            ious.append((intersection / union).item())
    # Compute the mean IoU across all classes
    mean_iou = torch.mean(torch.tensor(ious))
    return mean_iou  # Mean IoU is guaranteed to be in [0, 1]
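
# Optional usage sketch (not part of the original code): calculate_iou on tiny hand-made
# tensors, just to illustrate that the result is a scalar tensor in [0, 1] averaged over
# all classes (class 0 matches exactly, class 1 has IoU 0.5, class 2 has IoU 2/3).
def _demo_calculate_iou():
    pred   = torch.tensor([[0, 0, 1, 1],
                           [0, 0, 2, 2]])
    target = torch.tensor([[0, 0, 1, 2],
                           [0, 0, 2, 2]])
    print(calculate_iou(pred, target, num_classes=3))  # approx. 0.72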
#------------------------------------------------------------------------------#
# Training Loop
#------------------------------------------------------------------------------#
def train_cyclegan(real_A, real_B, seg_A, seg_B, model, optimizer_G, optimizer_D, epoch, batch_idx):
    model.train()
    device = next(model.parameters()).device
    # Move data to the correct device
    real_A = real_A.to(device)  # No need for requires_grad=True on inputs
    real_B = real_B.to(device)
    seg_A = seg_A.long().to(device) if seg_A is not None else None
    seg_B = seg_B.long().to(device) if seg_B is not None else None
    # Forward pass
    outputs = model(real_A, real_B, seg_A, seg_B)
    # Generator losses
    loss_G_A2B = model.criterion_adv(outputs['D_B_fake'], torch.ones_like(outputs['D_B_fake']))
    loss_G_B2A = model.criterion_adv(outputs['D_A_fake'], torch.ones_like(outputs['D_A_fake']))
    loss_G_adv = loss_G_A2B + loss_G_B2A
    loss_cycle_A = model.criterion_cycle(outputs['cycle_A'], real_A)
    loss_cycle_B = model.criterion_cycle(outputs['cycle_B'], real_B)
    loss_cycle = loss_cycle_A + loss_cycle_B
    loss_identity_A = model.criterion_identity(outputs['identity_A'], real_A)
    loss_identity_B = model.criterion_identity(outputs['identity_B'], real_B)
    loss_identity = loss_identity_A + loss_identity_B
    # Segmentation loss and IoU calculation
    loss_seg = 0.0
    mean_iou = 0.0
    if seg_A is not None and seg_B is not None:
        seg_fake_A = outputs['seg_fake_A'].float()
        seg_fake_B = outputs['seg_fake_B'].float()
        seg_A = seg_A.squeeze(1).long()  # Ensure shape (batch_size, height, width)
        seg_B = seg_B.squeeze(1).long()
        loss_seg_A = model.criterion_seg(seg_fake_A, seg_A)
        loss_seg_B = model.criterion_seg(seg_fake_B, seg_B)
        loss_seg = loss_seg_A + loss_seg_B
        pred_A = torch.argmax(outputs['seg_fake_A'], dim=1)  # Predicted classes for A
        pred_B = torch.argmax(outputs['seg_fake_B'], dim=1)  # Predicted classes for B
        mean_iou_A = calculate_iou(pred_A, seg_A, num_classes=3)
        mean_iou_B = calculate_iou(pred_B, seg_B, num_classes=3)
        mean_iou = (mean_iou_A + mean_iou_B) / 2
    # Total generator loss
    loss_G = loss_G_adv + 10.0 * loss_cycle + 5.0 * loss_identity + 1.0 * loss_seg
    # Discriminator losses
    loss_D_A = model.criterion_adv(outputs['D_A_real'], torch.ones_like(outputs['D_A_real'])) + \
               model.criterion_adv(outputs['D_A_fake'], torch.zeros_like(outputs['D_A_fake']))
    loss_D_B = model.criterion_adv(outputs['D_B_real'], torch.ones_like(outputs['D_B_real'])) + \
               model.criterion_adv(outputs['D_B_fake'], torch.zeros_like(outputs['D_B_fake']))
    loss_D = loss_D_A + loss_D_B
    # Backpropagation
    optimizer_G.zero_grad()
    loss_G.backward(retain_graph=True)  # Retain the graph for the discriminator backward pass
    optimizer_G.step()
    optimizer_D.zero_grad()
    loss_D.backward()  # No need to retain the graph here
    optimizer_D.step()
    return loss_G.item(), loss_D.item(), mean_iou
#----------------------------------------------------------------------------------#
def evaluate_cyclegan(real_A, real_B, seg_A, seg_B, model):
    # Evaluation: switch to eval mode, mirror the device handling of train_cyclegan,
    # and disable gradient tracking
    model.eval()
    device = next(model.parameters()).device
    real_A = real_A.to(device)
    real_B = real_B.to(device)
    seg_A = seg_A.long().to(device) if seg_A is not None else None
    seg_B = seg_B.long().to(device) if seg_B is not None else None
    with torch.no_grad():
        # Forward pass
        outputs = model(real_A, real_B, seg_A, seg_B)
        # Generator losses
        loss_G_A2B = model.criterion_adv(outputs['D_B_fake'], torch.ones_like(outputs['D_B_fake']))
        loss_G_B2A = model.criterion_adv(outputs['D_A_fake'], torch.ones_like(outputs['D_A_fake']))
        loss_G_adv = loss_G_A2B + loss_G_B2A
        loss_cycle_A = model.criterion_cycle(outputs['cycle_A'], real_A)
        loss_cycle_B = model.criterion_cycle(outputs['cycle_B'], real_B)
        loss_cycle = loss_cycle_A + loss_cycle_B
        loss_identity_A = model.criterion_identity(outputs['identity_A'], real_A)
        loss_identity_B = model.criterion_identity(outputs['identity_B'], real_B)
        loss_identity = loss_identity_A + loss_identity_B
        # Segmentation loss and IoU calculation
        loss_seg = 0.0
        mean_iou = 0.0
        if seg_A is not None and seg_B is not None:
            seg_fake_A = outputs['seg_fake_A'].float()
            seg_fake_B = outputs['seg_fake_B'].float()
            seg_A = seg_A.squeeze(1).long()
            seg_B = seg_B.squeeze(1).long()
            loss_seg_A = model.criterion_seg(seg_fake_A, seg_A)
            loss_seg_B = model.criterion_seg(seg_fake_B, seg_B)
            loss_seg = loss_seg_A + loss_seg_B
            pred_A = torch.argmax(outputs['seg_fake_A'], dim=1)
            pred_B = torch.argmax(outputs['seg_fake_B'], dim=1)
            mean_iou_A = calculate_iou(pred_A, seg_A, num_classes=3)
            mean_iou_B = calculate_iou(pred_B, seg_B, num_classes=3)
            mean_iou = (mean_iou_A + mean_iou_B) / 2
        # Total generator loss
        loss_G = loss_G_adv + 10.0 * loss_cycle + 5.0 * loss_identity + 1.0 * loss_seg
        # Discriminator losses (no gradients needed here)
        loss_D_A = model.criterion_adv(outputs['D_A_real'], torch.ones_like(outputs['D_A_real'])) + \
                   model.criterion_adv(outputs['D_A_fake'], torch.zeros_like(outputs['D_A_fake']))
        loss_D_B = model.criterion_adv(outputs['D_B_real'], torch.ones_like(outputs['D_B_real'])) + \
                   model.criterion_adv(outputs['D_B_fake'], torch.zeros_like(outputs['D_B_fake']))
        loss_D = loss_D_A + loss_D_B
    return loss_G.item(), loss_D.item(), mean_iou
#------------------------------------------------------------------------------#
# Weight Initialisation (Zero-Centered)
#------------------------------------------------------------------------------#
def zero_centered_init(m):
    """
    Apply zero-centered initialization to the weights of a layer.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)  # Zero-centered initialization
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # Initialize biases to zero
#------------------------------------------------------------------------------#
# Main Execution
#------------------------------------------------------------------------------#
if __name__ == "__main__":
    #------------------------------------------------------------------------------#
    # Hyperparameters
    #------------------------------------------------------------------------------#
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 3
    num_epochs = 1
    batch_size = 1
    learning_rate_G = 0.0001
    learning_rate_D = 0.0001
    replay_buffer_size = 50  # Size of the replay buffer
    #------------------------------------------------------------------------------#
    # Initialize model and optimizers
    #------------------------------------------------------------------------------#
    model = CycleGAN(num_classes).to(device)
    model.apply(zero_centered_init)
    # model.apply(init_weights)
    # The segmentation network's parameters are grouped with the generators so that the
    # segmentation term in loss_G actually updates it (assumed intent; remove if not wanted)
    optimizer_G = optim.Adam(list(model.G_A2B.parameters()) + list(model.G_B2A.parameters()) + list(model.Seg.parameters()), lr=learning_rate_G, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(list(model.D_A.parameters()) + list(model.D_B.parameters()), lr=learning_rate_D, betas=(0.5, 0.999))
    #------------------------------------------------------------------------------#
    # Schedulers
    #------------------------------------------------------------------------------#
    scheduler_G = StepLR(optimizer_G, step_size=30, gamma=0.1)
    scheduler_D = StepLR(optimizer_D, step_size=30, gamma=0.1)
    #------------------------------------------------------------------------------#
    # Dataset and DataLoader
    #------------------------------------------------------------------------------#
    data = r'C:\Users\Idrees Bhat\Desktop\reduced_dataset\source'
    labels = r'C:\Users\Idrees Bhat\Desktop\reduced_dataset\target'
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
    label_transform = transforms.Compose([transforms.ToTensor()])
    # Load the full dataset
    dataset = CustomSegmentationDataset(data, labels, image_transform=image_transform, label_transform=label_transform)
    # Set the split ratios
    train_size = int(0.8 * len(dataset))               # 80% for training
    val_size = int(0.1 * len(dataset))                 # 10% for validation
    test_size = len(dataset) - train_size - val_size   # Remaining 10% for testing
    # test_size = len(dataset) - train_size
    # Split the dataset
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    # Create a DataLoader for each split
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    #------------------------------------------------------------------------------#
    # Tracking losses
    #------------------------------------------------------------------------------#
    # Initialize lists to store values
    gen_losses = []
    disc_losses = []
    iou_values = []
    lr_G_values = []
    lr_D_values = []
    #------------------------------------------------------------------------------#
    # TRAINING LOOP
    #------------------------------------------------------------------------------#
    for epoch in range(num_epochs):
        total_gen_loss = 0.0
        total_disc_loss = 0.0
        total_iou = 0.0
        for i, (real_A, real_B) in enumerate(train_loader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            # Forward pass to generate fake images (for inspection only; not used in the losses below)
            fake_B = model.G_A2B(real_A)
            fake_A = model.G_B2A(real_B)
            # Train on this batch, passing real_B as the segmentation map for both domains
            loss_G, loss_D, mean_iou = train_cyclegan(real_A, real_B, real_B, real_B, model, optimizer_G, optimizer_D, epoch, i)
            # Get intermediate outputs from the model (again, only for inspection)
            outputs = model(real_A, real_B, real_B, real_B)
            # Accumulate per-batch statistics
            total_gen_loss += loss_G
            total_disc_loss += loss_D
            total_iou += float(mean_iou)
            if i % 5 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(train_loader)}] Loss_G: {loss_G:.4f} Loss_D: {loss_D:.4f} Mean IoU: {mean_iou:.4f}")
        # Record epoch averages and current learning rates
        gen_losses.append(total_gen_loss / len(train_loader))
        disc_losses.append(total_disc_loss / len(train_loader))
        iou_values.append(total_iou / len(train_loader))
        lr_G_values.append(optimizer_G.param_groups[0]['lr'])
        lr_D_values.append(optimizer_D.param_groups[0]['lr'])
        scheduler_G.step()
        scheduler_D.step()