Semantic Segmentation with Attention based CycleGAN

Hi, please read my scenario: my original images are grayscale, and my target images are also grayscale, but they are segmentation maps whose pixel values range from 1 to 3 (background = 1, one segmentation object = 2, another segmentation object = 3). I am designing a custom pixel-wise Generator and pixel-wise Discriminator for my model. Note that I am using cross-entropy loss as the segmentation loss (multi-class semantic segmentation), so the Generator has to pass raw logits to that loss function. Please see my code below.
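For reference, a minimal sketch of that loss call (not part of my model), assuming the 1..3 label values are shifted to 0..2 first, since nn.CrossEntropyLoss expects class indices in [0, num_classes - 1] and raw, un-softmaxed logits as input:

import torch
import torch.nn as nn

criterion_seg = nn.CrossEntropyLoss()

logits = torch.randn(2, 3, 128, 128)         # (batch, num_classes, H, W), raw logits
labels = torch.randint(1, 4, (2, 128, 128))  # grayscale labels with values 1..3
target = (labels - 1).long()                 # shift to class indices 0..2

loss = criterion_seg(logits, target)
print(loss.item())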

class CustomSegmentationDataset(Dataset):
    def __init__(self, data_dir, labels_dir, image_transform=None, label_transform=None):
        self.data_dir = data_dir
        self.labels_dir = labels_dir
        self.image_transform = image_transform
        self.label_transform = label_transform

        # List all image files in the data directory
        self.image_paths = [os.path.join(data_dir, fname)   for fname in os.listdir(data_dir) if fname.endswith('.jpg') or fname.endswith('.png')]
        self.label_paths = [os.path.join(labels_dir, fname) for fname in os.listdir(labels_dir) if fname.endswith('.jpg') or fname.endswith('.png')]

        # Ensure that the number of data and labels match
        assert len(self.image_paths) == len(self.label_paths), "Mismatch between number of images and labels"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')
        label = Image.open(self.label_paths[idx]).convert('L')  # Assuming labels are images

        # Apply the transformations if any
        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)
        return image, label

#------------------------------------------------------------------------------#

Dual ATTENTION (Channel Attention + Spatial Attention)

#--------------------------------------------------------------------------------#
class DualAttention(nn.Module):
    def __init__(self, in_channels):
        super(DualAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

class ChannelAttention(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // 16, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.avg_pool(x)
        max_out = self.max_pool(x)
        out = avg_out + max_out
        out = out.view(out.size(0), -1)  # Flatten to (batch_size, channels)
        out = self.fc(out).view(out.size(0), out.size(1), 1, 1)
        return x * out  # Apply attention to the original feature map

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(x_cat)
        return x * self.sigmoid(out)  # Apply attention to the original feature map

#------------------------------------------------------------------------------#

Custom Generator (Encoder Decoder Based)

#------------------------------------------------------------------------------#
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, padding_mode='reflect')
        self.bn1 = nn.BatchNorm2d(out_c, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        return x

class conv1_cross_1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        return self.conv1x1(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.attention = DualAttention(out_c)  # Add attention mechanism
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(out_c, affine=True),
            nn.ReLU(inplace=True)
        )
        self.attention = DualAttention(out_c)  # Add attention mechanism

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = x + skip  # Summation-based skip connection
        x = self.attention(x)  # Apply attention after the skip connection
        return x

class classifier_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.outputs = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
        )

    def forward(self, inputs):
        x = self.outputs(inputs)
        return x

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = encoder_block(1, 96)
        self.e2 = encoder_block(96, 96)
        self.e3 = encoder_block(96, 96)
        self.e4 = encoder_block(96, 96)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True)
        )

        self.d1 = decoder_block(96, 96)
        self.d2 = decoder_block(96, 96)
        self.d3 = decoder_block(96, 96)
        self.d4 = decoder_block(96, 96)
        self.outputs = classifier_block(96, 3)  # Output channels = 3 for 3 classes

    def forward(self, inputs):
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Output
        x = self.outputs(d4)  # Output shape: (batch_size, 3, height, width)
        return x

#------------------------------------------------------------------------------#

Custom Discriminator (encoder decoder based)

#------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------#

custom pixelwise encoder-decoder based discriminator

#-------------------------------------------------------------------------------#
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        x = self.conv(x)
        x = x + skip  # Addition-based skip connection
        # x = torch.cat((x, skip), dim=1)  # Skip connection
        return x

class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        # Encoder layers
        self.e1 = EncoderBlock(in_channels, 96)
        self.e2 = EncoderBlock(96, 96)
        self.e3 = EncoderBlock(96, 96)
        self.e4 = EncoderBlock(96, 96)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Decoder layers (for skip connections)
        self.d1 = DecoderBlock(96, 96)
        self.d2 = DecoderBlock(96, 96)
        self.d3 = DecoderBlock(96, 96)
        self.d4 = DecoderBlock(96, 96)

        # Final output layer (pixel-wise)
        self.outputs = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0, bias=False)  # Output 1 channel per pixel

    def forward(self, x):
        # Encoder forward pass
        s1 = self.e1(x)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)

        # Bottleneck (intermediate layer)
        b = self.bottleneck(s4)

        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Final pixel-wise output (raw logits)
        x = self.outputs(d4)
        # print(x.shape)
        return x

But after running I get the following error. The Generator expects a 1-channel input and produces a 3-channel output (since my labelled grayscale images contain three classes [1, 2, 3]), and the Discriminator takes a 3-channel input and produces a 1-channel output, yet it fails with:

RuntimeError: Given groups=1, weight of size [96, 1, 3, 3], expected input[1, 3, 258, 258] to have 1 channels, but got 3 channels instead

Hi Idrees!

Based on the code you’ve posted, the channels mismatch appears to be in the first layer
of your Generator. However, you don’t show where you instantiate a Generator nor to
what image you might apply that Generator.

If, as might be typical for a CycleGAN, you are feeding the output of your “transformation”
Generator to a second “reconstruction” Generator of the same architecture, you would
be feeding a three-channel image to a generator that expects a single channel.

Print out the shapes of any images to which you are applying one of your Generators,
see if any of them do have three channels, and then ask how many channels they logically
should have.

For your use case would it make sense to instantiate a Generator with three channels?
Or could it make sense to sum or average the three channels of one of those “input”
images to get a one-channel image compatible with the Generator?
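For concreteness, a rough sketch of those two options (the variable names here are illustrative, not from the posted code):

import torch

gen_out = torch.randn(1, 3, 256, 256)            # e.g. output of the "transformation" Generator

# Option 1: collapse the three channels to one before the second Generator
one_channel = gen_out.mean(dim=1, keepdim=True)  # shape (1, 1, 256, 256); .sum(...) works the same way
print(one_channel.shape)

# Option 2: build the second Generator so its first layer accepts three channels,
# e.g. change encoder_block(1, 96) to encoder_block(3, 96) in Generator.__init__.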

Best.

K. Frank

Please see my full code:

#------------------------------------------------------------------------------#

DEPENDENCIES

#------------------------------------------------------------------------------#
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
from PIL import Image
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet50
from matplotlib import pyplot as plt

from sklearn.model_selection import train_test_split

import random
from torch.utils.data import random_split
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
import torch.nn.init as init
import random
from collections import deque
import numpy as np
#------------------------------------------------------------------------------#

Custom Dataset Class

#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#
class CustomSegmentationDataset(Dataset):
    def __init__(self, data_dir, labels_dir, image_transform=None, label_transform=None):
        self.data_dir = data_dir
        self.labels_dir = labels_dir
        self.image_transform = image_transform
        self.label_transform = label_transform

        # List all image files in the data directory
        self.image_paths = [os.path.join(data_dir, fname)   for fname in os.listdir(data_dir) if fname.endswith('.jpg') or fname.endswith('.png')]
        self.label_paths = [os.path.join(labels_dir, fname) for fname in os.listdir(labels_dir) if fname.endswith('.jpg') or fname.endswith('.png')]

        # Ensure that the number of data and labels match
        assert len(self.image_paths) == len(self.label_paths), "Mismatch between number of images and labels"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')
        label = Image.open(self.label_paths[idx]).convert('L')  # Assuming labels are images

        # Apply the transformations if any
        if self.image_transform:
            image = self.image_transform(image)

        if self.label_transform:
            label = self.label_transform(label)
        return image, label

#------------------------------------------------------------------------------#

Dual ATTENTION (Channel Attention + Spatial Attention)

#--------------------------------------------------------------------------------#
class DualAttention(nn.Module):
    def __init__(self, in_channels):
        super(DualAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x

class ChannelAttention(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // 16, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        avg_out = self.avg_pool(x)
        max_out = self.max_pool(x)
        out = avg_out + max_out
        out = out.view(out.size(0), -1)  # Flatten to (batch_size, channels)
        out = self.fc(out).view(out.size(0), out.size(1), 1, 1)
        return x * out  # Apply attention to the original feature map

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        out = self.conv1(x_cat)
        return x * self.sigmoid(out)  # Apply attention to the original feature map

#------------------------------------------------------------------------------#

Custom Generator (Encoder Decoder Based)

#------------------------------------------------------------------------------#
class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, padding_mode='reflect')
        self.bn1 = nn.BatchNorm2d(out_c, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.relu(x)
        return x

class conv1_cross_1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        return self.conv1x1(x)

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.attention = DualAttention(out_c)  # Add attention mechanism
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(out_c, affine=True),
            nn.ReLU(inplace=True)
        )
        self.attention = DualAttention(out_c)  # Add attention mechanism

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = x + skip  # Summation-based skip connection
        x = self.attention(x)  # Apply attention after the skip connection
        return x

class classifier_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.outputs = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
        )

    def forward(self, inputs):
        x = self.outputs(inputs)
        return x

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.e1 = encoder_block(1, 96)
        self.e2 = encoder_block(96, 96)
        self.e3 = encoder_block(96, 96)
        self.e4 = encoder_block(96, 96)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True)
        )

        self.d1 = decoder_block(96, 96)
        self.d2 = decoder_block(96, 96)
        self.d3 = decoder_block(96, 96)
        self.d4 = decoder_block(96, 96)
        self.outputs = classifier_block(96, 3)  # Output channels = 3 for 3 classes

    def forward(self, inputs):
        # Encoder
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.bottleneck(p4)

        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Output
        x = self.outputs(d4)  # Output shape: (batch_size, 3, height, width)
        # print('at generation level')
        # print(x.shape)
        return x

#------------------------------------------------------------------------------#

Custom Discriminator (encoder decoder based)

#------------------------------------------------------------------------------#
#-------------------------------------------------------------------------------#

custom pixelwise encoder-decoder based discriminator

#-------------------------------------------------------------------------------#
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        x = self.conv(x)
        x = x + skip  # Addition-based skip connection
        # x = torch.cat((x, skip), dim=1)  # Skip connection
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Encoder layers
        self.e1 = EncoderBlock(3, 96)
        self.e2 = EncoderBlock(96, 96)
        self.e3 = EncoderBlock(96, 96)
        self.e4 = EncoderBlock(96, 96)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(96, 96, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Decoder layers (for skip connections)
        self.d1 = DecoderBlock(96, 96)
        self.d2 = DecoderBlock(96, 96)
        self.d3 = DecoderBlock(96, 96)
        self.d4 = DecoderBlock(96, 96)

        # Final output layer (pixel-wise)
        self.outputs = nn.Conv2d(96, 1, kernel_size=1, stride=1, padding=0, bias=False)  # Output 1 channel per pixel

    def forward(self, x):
        # Encoder forward pass
        # print("before discrimination", x.shape)
        s1 = self.e1(x)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)
        # print('before bottleneck at disc level', s4.shape)
        # Bottleneck (intermediate layer)
        b = self.bottleneck(s4)
        # print('after bottleneck at disc level', b.shape)

        # Decoder with skip connections
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Final pixel-wise output (raw logits)
        x = self.outputs(d4)
        # print('at discriminator level')
        # print(x.shape)
        return x

#-----------------------------------------------------------------------------#

#-------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#

Custom Segmentation Unet

#------------------------------------------------------------------------------#
class SegmentationModel(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(SegmentationModel, self).__init__()
        self.encoder1 = self.conv_block(in_channels, 96)
        self.encoder2 = self.conv_block(96, 96)
        self.encoder3 = self.conv_block(96, 96)
        self.encoder4 = self.conv_block(96, 96)
        self.bottleneck = self.conv_block(96, 96)
        self.decoder4 = self.conv_block(96, 96)
        self.decoder3 = self.conv_block(96, 96)
        self.decoder2 = self.conv_block(96, 96)
        self.decoder1 = self.conv_block(96, 96)
        self.final_conv = nn.Conv2d(96, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        bottleneck = self.bottleneck(enc4)
        dec4 = self.decoder4(bottleneck)
        dec3 = self.decoder3(dec4 + enc4)
        dec2 = self.decoder2(dec3 + enc3)
        dec1 = self.decoder1(dec2 + enc2)
        x = self.final_conv(dec1 + enc1)
        # print('at segmentation')
        # print(x.shape)
        return x

#------------------------------------------------------------------------------#

Custom CycleGAN

#------------------------------------------------------------------------------#
class CycleGAN(nn.Module):
    def __init__(self, num_classes):
        super(CycleGAN, self).__init__()
        self.G_A2B = Generator()
        self.G_B2A = Generator()
        self.D_A = Discriminator()
        self.D_B = Discriminator()
        self.Seg = SegmentationModel()
        self.criterion_adv = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()
        self.criterion_seg = nn.CrossEntropyLoss()

    def forward(self, real_A, real_B, seg_A, seg_B):
        fake_B = self.G_A2B(real_A)
        fake_A = self.G_B2A(real_B)
        D_A_fake = self.D_A(fake_A.detach())
        D_B_fake = self.D_B(fake_B.detach())
        cycle_A = self.G_B2A(fake_B)
        cycle_B = self.G_A2B(fake_A)
        identity_A = self.G_B2A(real_A)
        identity_B = self.G_A2B(real_B)
        D_A_real = self.D_A(real_A)
        D_B_real = self.D_B(real_B)
        seg_fake_B = self.Seg(fake_B)
        seg_fake_A = self.Seg(fake_A)
        return {
            'fake_A': fake_A,
            'fake_B': fake_B,
            'cycle_A': cycle_A,
            'cycle_B': cycle_B,
            'identity_A': identity_A,
            'identity_B': identity_B,
            'D_A_real': D_A_real,
            'D_A_fake': D_A_fake,
            'D_B_real': D_B_real,
            'D_B_fake': D_B_fake,
            'seg_fake_A': seg_fake_A,
            'seg_fake_B': seg_fake_B
        }

#------------------------------------------------------------------------------#

IoU Calculation

#------------------------------------------------------------------------------#
def calculate_iou(pred, target, num_classes):
    pred = pred.view(-1)      # Flatten the prediction tensor
    target = target.view(-1)  # Flatten the target tensor
    ious = []
    for c in range(num_classes):
        pred_c = (pred == c)      # Boolean mask for predicted class
        target_c = (target == c)  # Boolean mask for target class

        intersection = (pred_c & target_c).sum().float()  # Intersection
        union = (pred_c | target_c).sum().float()         # Union

        # Handle the case where union is zero (no pixels for this class)
        if union == 0:
            ious.append(0.0)  # IoU is 0 if there are no pixels for this class
        else:
            iou = intersection / union
            ious.append(iou)
    # Compute the mean IoU across all classes
    mean_iou = torch.mean(torch.tensor(ious))
    return mean_iou  # Mean IoU is guaranteed to be in [0, 1]

#-------------------------------------------------------------------------------#

#------------------------------------------------------------------------------#

Training Loop

#------------------------------------------------------------------------------#
def train_cyclegan(real_A, real_B, seg_A, seg_B, model, optimizer_G, optimizer_D, epoch, batch_idx):
    model.train()
    device = next(model.parameters()).device

    # Move data to the correct device
    real_A = real_A.to(device)  # No need for requires_grad=True for inputs
    real_B = real_B.to(device)  # No need for requires_grad=True for inputs
    seg_A = seg_A.long().to(device) if seg_A is not None else None
    seg_B = seg_B.long().to(device) if seg_B is not None else None

    # Forward pass
    outputs = model(real_A, real_B, seg_A, seg_B)

    # Generator losses
    loss_G_A2B = model.criterion_adv(outputs['D_B_fake'], torch.ones_like(outputs['D_B_fake']).to(device))
    loss_G_B2A = model.criterion_adv(outputs['D_A_fake'], torch.ones_like(outputs['D_A_fake']).to(device))
    loss_G_adv = loss_G_A2B + loss_G_B2A

    loss_cycle_A = model.criterion_cycle(outputs['cycle_A'], real_A)
    loss_cycle_B = model.criterion_cycle(outputs['cycle_B'], real_B)
    loss_cycle = loss_cycle_A + loss_cycle_B

    loss_identity_A = model.criterion_identity(outputs['identity_A'], real_A)
    loss_identity_B = model.criterion_identity(outputs['identity_B'], real_B)
    loss_identity = loss_identity_A + loss_identity_B

    # Segmentation loss and IoU calculation
    loss_seg = 0.0
    mean_iou = 0.0
    if seg_A is not None and seg_B is not None:
        seg_fake_A = outputs['seg_fake_A'].float()
        seg_fake_B = outputs['seg_fake_B'].float()
        seg_A = seg_A.squeeze(1).long()  # Ensure it is in shape (batch_size, height, width)
        seg_B = seg_B.squeeze(1).long()  # Ensure it is in shape (batch_size, height, width)

        loss_seg_A = model.criterion_seg(seg_fake_A, seg_A)
        loss_seg_B = model.criterion_seg(seg_fake_B, seg_B)
        loss_seg = loss_seg_A + loss_seg_B

        pred_A = torch.argmax(outputs['seg_fake_A'], dim=1)  # Get the predicted classes
        pred_B = torch.argmax(outputs['seg_fake_B'], dim=1)  # Same for B

        mean_iou_A = calculate_iou(pred_A, seg_A, num_classes=3)
        mean_iou_B = calculate_iou(pred_B, seg_B, num_classes=3)
        mean_iou = (mean_iou_A + mean_iou_B) / 2

    # Total generator loss
    loss_G = loss_G_adv + 10.0 * loss_cycle + 5.0 * loss_identity + 1.0 * loss_seg

    # Discriminator losses
    loss_D_A = model.criterion_adv(outputs['D_A_real'], torch.ones_like(outputs['D_A_real']).to(device)) + \
               model.criterion_adv(outputs['D_A_fake'], torch.zeros_like(outputs['D_A_fake']).to(device))
    loss_D_B = model.criterion_adv(outputs['D_B_real'], torch.ones_like(outputs['D_B_real']).to(device)) + \
               model.criterion_adv(outputs['D_B_fake'], torch.zeros_like(outputs['D_B_fake']).to(device))
    loss_D = loss_D_A + loss_D_B

    # Backpropagation
    optimizer_G.zero_grad()
    loss_G.backward(retain_graph=True)  # Retain the graph for the next backward pass
    optimizer_G.step()

    optimizer_D.zero_grad()
    loss_D.backward()  # No need to retain the graph here
    optimizer_D.step()

    return loss_G.item(), loss_D.item(), mean_iou

#----------------------------------------------------------------------------------#
def evaluate_cyclegan(real_A, real_B, seg_A, seg_B, model):
    # Forward pass
    outputs = model(real_A, real_B, seg_A, seg_B)

    # Generator losses
    loss_G_A2B = model.criterion_adv(outputs['D_B_fake'], torch.ones_like(outputs['D_B_fake']).to(device))
    loss_G_B2A = model.criterion_adv(outputs['D_A_fake'], torch.ones_like(outputs['D_A_fake']).to(device))
    loss_G_adv = loss_G_A2B + loss_G_B2A

    loss_cycle_A = model.criterion_cycle(outputs['cycle_A'], real_A)
    loss_cycle_B = model.criterion_cycle(outputs['cycle_B'], real_B)
    loss_cycle = loss_cycle_A + loss_cycle_B

    loss_identity_A = model.criterion_identity(outputs['identity_A'], real_A)
    loss_identity_B = model.criterion_identity(outputs['identity_B'], real_B)
    loss_identity = loss_identity_A + loss_identity_B

    # Segmentation loss and IoU calculation
    loss_seg = 0.0
    mean_iou = 0.0
    if seg_A is not None and seg_B is not None:
        seg_fake_A = outputs['seg_fake_A'].float()
        seg_fake_B = outputs['seg_fake_B'].float()
        seg_A = seg_A.squeeze(1).long()
        seg_B = seg_B.squeeze(1).long()

        loss_seg_A = model.criterion_seg(seg_fake_A, seg_A)
        loss_seg_B = model.criterion_seg(seg_fake_B, seg_B)
        loss_seg = loss_seg_A + loss_seg_B

        pred_A = torch.argmax(outputs['seg_fake_A'], dim=1)
        pred_B = torch.argmax(outputs['seg_fake_B'], dim=1)

        mean_iou_A = calculate_iou(pred_A, seg_A, num_classes=3)
        mean_iou_B = calculate_iou(pred_B, seg_B, num_classes=3)
        mean_iou = (mean_iou_A + mean_iou_B) / 2

    # Total generator loss
    loss_G = loss_G_adv + 10.0 * loss_cycle + 5.0 * loss_identity + 1.0 * loss_seg

    # Discriminator losses (but we don't need to calculate gradients here)
    loss_D_A = model.criterion_adv(outputs['D_A_real'], torch.ones_like(outputs['D_A_real']).to(device)) + \
               model.criterion_adv(outputs['D_A_fake'], torch.zeros_like(outputs['D_A_fake']).to(device))
    loss_D_B = model.criterion_adv(outputs['D_B_real'], torch.ones_like(outputs['D_B_real']).to(device)) + \
               model.criterion_adv(outputs['D_B_fake'], torch.zeros_like(outputs['D_B_fake']).to(device))
    loss_D = loss_D_A + loss_D_B

    return loss_G.item(), loss_D.item(), mean_iou

#------------------------------------------------------------------------------#

Initialisation of Weights (Zero-Centered)

#------------------------------------------------------------------------------#
def zero_centered_init(m):
    """
    Apply zero-centered initialization to the weights of a layer.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)  # Zero-centered initialization
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # Initialize biases to zero
#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#
#------------------------------------------------------------------------------#

Main Execution

#------------------------------------------------------------------------------#
if __name__ == "__main__":
    #------------------------------------------------------------------------------#
    # Hyperparameters
    #------------------------------------------------------------------------------#
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 3
    num_epochs = 1
    batch_size = 1
    learning_rate_G = 0.0001
    learning_rate_D = 0.0001
    replay_buffer_size = 50  # Size of the replay buffer
    #------------------------------------------------------------------------------#
    # Initialize model and optimizers
    #------------------------------------------------------------------------------#
    model = CycleGAN(num_classes).to(device)
    model.apply(zero_centered_init)
    # model.apply(init_weights)
    optimizer_G = optim.Adam(list(model.G_A2B.parameters()) + list(model.G_B2A.parameters()), lr=learning_rate_G, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(list(model.D_A.parameters()) + list(model.D_B.parameters()), lr=learning_rate_D, betas=(0.5, 0.999))
    #---------------------------------------------------------#
    # SCHEDULERS
    #----------------#
    scheduler_G = StepLR(optimizer_G, step_size=30, gamma=0.1)
    scheduler_D = StepLR(optimizer_D, step_size=30, gamma=0.1)
    #------------------------------------------------------------------------------#
    # Dataset and DataLoader
    #------------------------------------------------------------------------------#
    data = r'C:\Users\Idrees Bhat\Desktop\reduced_dataset\source'
    labels = r'C:\Users\Idrees Bhat\Desktop\reduced_dataset\target'
    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
    label_transform = transforms.Compose([transforms.ToTensor()])
    # Load the full dataset
    dataset = CustomSegmentationDataset(data, labels, image_transform=image_transform, label_transform=label_transform)

    # Set the split ratios
    train_size = int(0.8 * len(dataset))   # 80% for training
    val_size   = int(0.1 * len(dataset))   # 10% for validation
    test_size  = len(dataset) - train_size - val_size  # Remaining 10% for testing
    # test_size = len(dataset) - train_size
    # Split the dataset
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    # Create DataLoader for each split
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)
    # print('')
    # print(type(train_data))
    #-------------------------------------------------------------------------------------------------------#
    # Tracking losses
    #--------------------------------------------------------------------------------------------------------#
    # Initialize lists to store values
    gen_losses = []
    disc_losses = []
    iou_values = []
    lr_G_values = []
    lr_D_values = []

    #-----------------------------------------------------------------------------------------------------------#
    # TRAINING LOOP
    #------------------------------------------------------------------------------------------------------------#
    for epoch in range(num_epochs):
        total_gen_loss = 0.0
        total_disc_loss = 0.0
        total_iou = 0.0
        for i, (real_A, real_B) in enumerate(train_loader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)
            # Forward pass to generate fake images
            fake_B = model.G_A2B(real_A)
            fake_A = model.G_B2A(real_B)

            # Pass real_A and real_B as segmentation maps
            loss_G, loss_D, mean_iou = train_cyclegan(real_A, real_B, real_B, real_B, model, optimizer_G, optimizer_D, epoch, i)

            # Get intermediate outputs from the model
            outputs = model(real_A, real_B, real_B, real_B)

            if i % 5 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(train_loader)}] Loss_G: {loss_G:.4f} Loss_D: {loss_D:.4f} Mean IoU: {mean_iou:.4f}")

            scheduler_G.step()
            scheduler_D.step()

I am doing semantic segmentation with an attention-based CycleGAN. I expect a grayscale image as input to the Generator, since my target images are grayscale semantic segmentation maps with values [1, 2, 3]. As this is a multi-class segmentation problem, the Generator's output should have 3 channels, and the Discriminator should accept 3 channels as input and produce a 1-channel output. With this configuration I get the error above, but when I set the Generator output and the Discriminator input to 1 channel, it runs fine. Note that I am using cross-entropy as an additional segmentation loss.
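For reference, a quick shape check of the flow I intend (a sketch only, using a 256x256 dummy input so the pooling sizes line up; Generator and Discriminator are the classes posted above):

import torch

G = Generator()                  # first layer expects 1 channel, output has 3 channels (one per class)
D = Discriminator()              # first layer expects 3 channels, output is 1 channel per pixel

x = torch.randn(1, 1, 256, 256)  # grayscale input batch
logits = G(x)                    # -> (1, 3, 256, 256) raw class logits
score = D(logits)                # -> (1, 1, 128, 128) pixel-wise real/fake logits
print(logits.shape, score.shape)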

Hi Idrees!

More code isn’t helpful.

My recommendation:

Find the spot in your code where the error actually occurs. It is probably something like:

output = myGenerator (myThreeChannelImage)

where myThreeChannelImage is a tensor of shape [1, 3, 258, 258], as reported in
the error message you posted. (That is, myThreeChannelImage is a batch containing
a single sample that is a 258x258 three-channel image). myGenerator would be one
of the instances of your Generator class in your CycleGAN.

Then verify that this is causing the issue with something like:

singleChannelTestImage = torch.randn (1, 1, 258, 258)
print (singleChannelTestImage.shape)
print ('calling myGenerator (singleChannelTestImage) ...')
_ = myGenerator (singleChannelTestImage)
print ('myGenerator (singleChannelTestImage) returned.')

print (myThreeChannelImage.shape)
print ('calling myGenerator (myThreeChannelImage) ...')
output = myGenerator (myThreeChannelImage)
print ('expected error before here')

If this confirms the analysis of the cause of your error, you then need to understand why
you are trying to pass a three-channel image to a model that expects a single-channel
input. Based on the purpose of this CycleGAN, how many channels should this image
have?
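As a concrete illustration of that question, a minimal sketch based on the posted classes (assuming they are in scope and using a 256x256 input for simplicity): in the posted CycleGAN.forward, the cycle step appears to be the first place a three-channel tensor reaches a Generator whose first layer is encoder_block(1, 96), and the segmentation branch Seg(fake_B) would hit the same mismatch.

import torch

G_A2B = Generator()                   # first layer: encoder_block(1, 96)
G_B2A = Generator()                   # same architecture, also expects 1 input channel

real_A = torch.randn(1, 1, 256, 256)
fake_B = G_A2B(real_A)                # -> (1, 3, 256, 256): three class-logit channels
cycle_A = G_B2A(fake_B)               # raises the same "expected 1 channels, got 3" RuntimeError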

Best.

K. Frank