I am working on an RGB image restoration task using a U-Net, but my model always gives me a completely black image as output.
import torch
import torch.nn as nn
import torch.nn.functional as F
class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        # prompt_len learnable prompts, each of shape (prompt_dim, prompt_size, prompt_size)
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        # global average pooling gives a per-image embedding -> prompt weights
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        # weighted sum of the learnable prompts, broadcast over the batch
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.repeat(B, 1, 1, 1, 1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear", align_corners=False)
        prompt = self.conv3x3(prompt)
        return prompt
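A quick shape check of the block in isolation (using the same arguments as prompt1 in the U-Net below; block and feat are throwaway names for this test):

block = PromptGenBlock(prompt_dim=256, prompt_len=5, prompt_size=32, lin_dim=256)
feat = torch.randn(2, 256, 32, 32)   # (B, lin_dim, H, W)
print(block(feat).shape)             # torch.Size([2, 256, 32, 32])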
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
class Unet(nn.Module):
    def __init__(self, inp_ch, out_ch):
        super(Unet, self).__init__()
        self.init_conv = DoubleConv(inp_ch, 64)
        self.pool = nn.MaxPool2d(2)
        self.enc1 = DoubleConv(64, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.dec1 = DoubleConv(512, 256)
        self.dec2 = DoubleConv(256, 128)
        self.dec3 = DoubleConv(128, 64)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        # PromptGenBlock(prompt_dim, prompt_len, prompt_size, lin_dim)
        self.prompt1 = PromptGenBlock(256, 5, 32, 256)
        self.prompt2 = PromptGenBlock(128, 5, 64, 128)
        self.prompt3 = PromptGenBlock(64, 5, 128, 64)
        self.final_deconv = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.final_conv = DoubleConv(64, out_ch)
    def forward(self, x):
        init_conv = self.init_conv(x)
        enc1 = self.enc1(init_conv)
        enc1_down = self.pool(enc1)
        enc2 = self.enc2(enc1_down)
        enc2_down = self.pool(enc2)
        enc3 = self.enc3(enc2_down)
        enc3_down = self.pool(enc3)
        enc4 = self.enc4(enc3_down)
        enc4_down = self.pool(enc4)
        up3 = self.upconv3(enc4_down)
        prompt1 = self.prompt1(up3)
        # the upsampled features are used only to generate the prompt;
        # the decoder input is the skip connection concatenated with the prompt
        dec1_input = torch.cat([enc3_down, prompt1], dim=1)
        dec1_output = self.dec1(dec1_input)
        up2 = self.upconv2(dec1_output)
        prompt2 = self.prompt2(up2)
        dec2_input = torch.cat([enc2_down, prompt2], dim=1)
        dec2_output = self.dec2(dec2_input)
        up1 = self.upconv1(dec2_output)
        prompt3 = self.prompt3(up1)
        dec3_input = torch.cat([enc1_down, prompt3], dim=1)
        dec3_output = self.dec3(dec3_input)
        final_deconv = self.final_deconv(dec3_output)
        final_conv = self.final_conv(final_deconv)
        # global residual: the network predicts a correction on top of the input
        return x + final_conv
img = torch.randn((1, 3, 256, 256))
model = Unet(3, 3)
res = model(img)
print(res.shape)  # torch.Size([1, 3, 256, 256])
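Since the dataloader below normalizes images to [-1, 1], outputs have to be mapped back before saving; this is the conversion I use to inspect predictions (a minimal sketch, and tensor_to_image is just a name for this post):

import numpy as np

def tensor_to_image(t):
    # undo Normalize(mean=0.5, std=0.5) and convert to a uint8 RGB array
    t = t.squeeze(0).detach().cpu()      # (3, H, W)
    t = (t * 0.5 + 0.5).clamp(0.0, 1.0)  # [-1, 1] -> [0, 1]
    return (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

print(tensor_to_image(res).shape)  # (256, 256, 3)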
class CombinedLoss(nn.Module):
    def __init__(self, alpha=1.0, beta=1.0):
        super(CombinedLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.perceptual_loss = PerceptualLoss()  # not shown here
        self.alpha = alpha
        self.beta = beta

    def forward(self, output, target):
        mse = self.mse_loss(output, target)
        perceptual = self.perceptual_loss(output, target)
        total_loss = self.alpha * mse + self.beta * perceptual
        return total_loss
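PerceptualLoss is not shown above; for completeness, a minimal VGG16-based version along the lines this code assumes (a sketch — the layer cutoff and loss choice here are assumptions, not necessarily what I run):

from torchvision.models import vgg16

class PerceptualLoss(nn.Module):
    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # VGG16 features up to relu3_3, frozen and kept in eval mode
        self.features = vgg16(weights="IMAGENET1K_V1").features[:16].eval()
        for p in self.features.parameters():
            p.requires_grad = False
        self.criterion = nn.MSELoss()

    def forward(self, output, target):
        return self.criterion(self.features(output), self.features(target))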
Dataloader

import os
import cv2
from PIL import Image
from natsort import natsorted
from torch.utils.data import Dataset
from torchvision import transforms

class PoledTrainDataset(Dataset):
    none_count_hq = 0
    none_count_lq = 0

    def __init__(self, target_size=(256, 256)):
        self.poled_path = r'Train/Poled'
        self.poled_hq = os.path.join(self.poled_path, 'HQ')
        self.poled_lq = os.path.join(self.poled_path, 'LQ')
        self.poled_hq_img_path = natsorted([os.path.join(self.poled_hq, img) for img in os.listdir(self.poled_hq)])
        self.poled_lq_img_path = natsorted([os.path.join(self.poled_lq, img) for img in os.listdir(self.poled_lq)])
        self.transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            # maps each channel from [0, 1] to [-1, 1]
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    def __len__(self):
        return len(self.poled_lq_img_path)

    def __getitem__(self, idx):
        while True:
            hq_path = self.poled_hq_img_path[idx]
            lq_path = self.poled_lq_img_path[idx]
            hq_img = cv2.imread(hq_path)
            lq_img = cv2.imread(lq_path)
            # skip unreadable files and try the next index
            if hq_img is None:
                idx = (idx + 1) % len(self.poled_hq_img_path)
                continue
            if lq_img is None:
                idx = (idx + 1) % len(self.poled_lq_img_path)
                continue
            hq_img = cv2.cvtColor(hq_img, cv2.COLOR_BGR2RGB)
            lq_img = cv2.cvtColor(lq_img, cv2.COLOR_BGR2RGB)
            # convert to PIL Image and apply the resize/normalize transform
            hq_img_pil = Image.fromarray(hq_img)
            lq_img_pil = Image.fromarray(lq_img)
            hq_img_resized = self.transform(hq_img_pil)
            lq_img_resized = self.transform(lq_img_pil)
            # note: returns (target, input) in this order
            return hq_img_resized, lq_img_resized
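For completeness, roughly how the pieces fit together (a minimal sketch; the batch size, beta, and learning rate are placeholders, not my exact values). Note that __getitem__ returns (hq, lq), i.e. (target, input), so the loop has to unpack in that order:

from torch.utils.data import DataLoader

dataset = PoledTrainDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = Unet(3, 3)
criterion = CombinedLoss(alpha=1.0, beta=0.1)  # placeholder weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for hq, lq in loader:      # (target, input): the order __getitem__ returns
    optimizer.zero_grad()
    pred = model(lq)       # LQ image in, restored image out
    loss = criterion(pred, hq)
    loss.backward()
    optimizer.step()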