I’ve implemented a UNET style architecture for image denoising and it works well. I’m now trying to add a filter, written with Pytorch, inside my network, after the UNET architecture to further denoise the image before returning an output for backpropagation.
My problem is that after adding this filter, my loss (MSE) is returning tensor NaN.
Any ideas as to why this is?
"""Implementation of neural net denoise """
from datetime import datetime
import os
import torch
from torch import nn
# from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
# from autoencoder_grayscale import AutoEncoder
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
# from pyramid_loss import LapLoss
# from skimage.color import rgb2gray
from wiener_3d import wiener_3d
torch.set_printoptions(linewidth=120)
# FUNCTION TO ADD NOISE DATA
def add_noise(image, std, device):
""" Add noise to images """
noise = torch.randn(image.size()) * (std)
noise = noise.to(device)
noisy_image = image + noise
return noisy_image
# BATCH NORMALISE IMAGES
def min_max_normalization(tensor, min_value, max_value):
"""Normalizing transformation"""
min_tensor = tensor.min()
tensor = (tensor - min_tensor)
max_tensor = tensor.max()
tensor = tensor / max_tensor
tensor = tensor * (max_value - min_value) + min_value
return tensor
def to_img(image, c, width, height):
"""Reformat tensor for printing """
image = image.view(image.size(0), 1, width, height)
return image
# VISUALISE INDIVIDUAL IMAGES
def show_image(image):
""" plot inidividual images"""
new_image = image.permute(1, 2, 0)
plt.imshow(new_image)
plt.show()
# PLOT ALL IN A BATCH
# just use to_img or view instead...
def plot_batch(images, dimension):
""" Plot every image in the batch"""
grid = torchvision.utils.make_grid(images, nrow=6)
plt.figure(figsize=(dimension, dimension))
plt.imshow(np.transpose(grid, (1, 2, 0)), norm=None)
# DATALOADER
def load_dataset(size_batch, size):
""" Get dataset and return dataloader """
data_path = "test/kodak_validation/"
transformations = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.CenterCrop(size),
transforms.Resize(size),
transforms.ToTensor(),
# transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# transforms.Normalize([0], [1])
])
train_dataset = datasets.ImageFolder(
root=data_path,
transform=transformations
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=size_batch,
num_workers=0,
shuffle=True
)
print("loaded")
return train_loader
class AutoEncoder(nn.Module):
"""Autoencoder simple implementation """
def __init__(self):
super(AutoEncoder, self).__init__()
# Encoder
# conv layer
self.block1 = nn.Sequential(
nn.Conv2d(1, 48, 3, padding=1),
nn.Conv2d(48, 48, 3, padding=1),
nn.MaxPool2d(2),
nn.BatchNorm2d(48),
nn.LeakyReLU(0.1)
)
self.block2 = nn.Sequential(
nn.Conv2d(48, 48, 3, padding=1),
nn.MaxPool2d(2),
nn.BatchNorm2d(48),
nn.LeakyReLU(0.1)
)
self.block3 = nn.Sequential(
nn.Conv2d(48, 48, 3, padding=1),
nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
nn.BatchNorm2d(48),
nn.LeakyReLU(0.1)
)
self.block4 = nn.Sequential(
nn.Conv2d(96, 96, 3, padding=1),
nn.Conv2d(96, 96, 3, padding=1),
nn.ConvTranspose2d(96, 96, 2, 2),
nn.BatchNorm2d(96),
nn.LeakyReLU(0.1)
)
self.block5 = nn.Sequential(
nn.Conv2d(144, 96, 3, padding=1),
nn.Conv2d(96, 96, 3, padding=1),
nn.ConvTranspose2d(96, 96, 2, 2),
nn.BatchNorm2d(96),
nn.LeakyReLU(0.1)
)
self.block6 = nn.Sequential(
nn.Conv2d(97, 64, 3, padding=1),
nn.BatchNorm2d(64),
nn.Conv2d(64, 32, 3, padding=1),
nn.BatchNorm2d(32),
nn.Conv2d(32, 1, 3, padding=1),
nn.LeakyReLU(0.1)
)
def forward(self, x):
# print("input: ", x.shape)
pool1 = self.block1(x)
# print("pool1: ", pool1.shape)
pool2 = self.block2(pool1)
# print("pool2: ", pool2.shape)
pool3 = self.block2(pool2)
# print("pool3: ", pool3.shape)
pool4 = self.block2(pool3)
# print("pool4: ", pool4.shape)
pool5 = self.block2(pool4)
# print("pool5: ", pool5.shape)
upsample5 = self.block3(pool5)
# print("upsample5: ", upsample5.shape)
concat5 = torch.cat((upsample5, pool4), 1)
# print("concat5: ", concat5.shape)
upsample4 = self.block4(concat5)
# print("upsample4: ", upsample4.shape)
concat4 = torch.cat((upsample4, pool3), 1)
# print("concat4: ", concat4.shape)
upsample3 = self.block5(concat4)
# print("upsample3: ", upsample3.shape)
concat3 = torch.cat((upsample3, pool2), 1)
# print("concat3: ", concat3.shape)
upsample2 = self.block5(concat3)
# print("upsample2: ", upsample2.shape)
concat2 = torch.cat((upsample2, pool1), 1)
# print("concat2: ", concat2.shape)
upsample1 = self.block5(concat2)
# print("upsample1: ", upsample1.shape)
concat1 = torch.cat((upsample1, x), 1)
# print("concat1: ", concat1.shape)
output = self.block6(concat1)
# print("output: ", output.shape)
# STEP 2: SUBTRACT NETWORK OUTPUT FROM NOISY IMAGE TO GET TEXTURE MAP
t_map = x - output
# Convert to size for Wiener filtering
for i in range(4):
tensor = t_map[i, :, :, :] # Take each item in batch separately. Could account for this in Wiener instead
# torchvision.utils.save_image(tensor[:, :, :], 'x{}subtract.png'.format(i))
tensor = torch.squeeze(tensor) # Squeeze for Wiener input format
# STEP 3: APPLY WIENER TO TEXTURE MAP
tensor = wiener_3d(tensor, 0.05, 10) # Apply Wiener with specified std and block size
tensor = torch.unsqueeze(tensor, 0) # unsqueeze to put back into block
# torchvision.utils.save_image(tensor[:, :, :], 'x{}wiener_tmap.png'.format(i))
t_map[i, :, :, :] = tensor # put back into block
# STEP 4: ADD FILTERED TEXTURE MAP BACK ONTO NET OUTPUT TO RESTORE DETAIL
filtered_output = output + t_map
return filtered_output
def train_gray(epoch):
train_loss = 0.0
for data in data_loader:
img, _ = data
img = img.to(device)
noisy_img = add_noise(img, 0.05, device)
# forward pass
# STEP 1: APPLY netowrk TO NOISY IMAGE
output = model(noisy_img)
if (epoch % 5 == 0):
torchvision.utils.save_image(filtered_output, 'x_out{}.png'.format(epoch))
loss = criterion(output, img)
# backwards
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()*img.size(0)
train_loss = train_loss/len(data_loader)
print('Epoch: {} \tTraining Loss: {:.6f}'.format(
epoch,
train_loss
))
return train_loss
def checkpoint(epoch, train_loss):
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss
}, path+"/model_epoch_{}.pt".format(epoch))
print("Epoch saved")
now = datetime.now()
current_time = now.strftime("%H_%M_%S")
path = "test/test_training_gray/{}".format(current_time)
os.mkdir(path)
width = 112
height = 112
num_epochs = 100
batch_size = 4
learning_rate = 0.0001
data_loader = load_dataset(batch_size, width)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = AutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
model.parameters(), lr=learning_rate, weight_decay=1e-5)
############################################################################################
# UNCOMMENT CODE BELOW FOR RESUMING FROM A MODEL
# model = TheModelClass(*args, **kwargs)
# optimizer = TheOptimizerClass(*args, **kwargs)
# model_path = "test/test_training_gray/22_24_55/model_epoch_100.pt"
# save_point = torch.load(model_path)
# model.load_state_dict(save_point['model_state_dict'])
# optimizer.load_state_dict(save_point['optimizer_state_dict'])
# epoch = save_point['epoch']
# #If we wish to do remaining epochs, num_epochs = num_epochs-epoch
# train_loss = save_point['train_loss']
# model.train()
############################################################################################
for i in range(1, num_epochs+1):
train_loss = train_gray(i)
checkpoint(i, train_loss)
print("end")
Below is the filter I am implementing:
import numpy as np
import torch
def add_noise(img, std):
noise = torch.randn(img.size()) * (std)
noisy_image = img + noise
return noisy_image
def wiener_3d(I, noise_std, block_size):
width = I.shape[1]
height = I.shape[0]
IR = torch.zeros(height, width, dtype=torch.float64)
# if(len(list(I.shape)) >= 3):
# frames = I.shape[2]
# else:
# bt = 1
bt = 1
bx = block_size
by = block_size
hbx = bx/2
hby = by/2
hbt = bt/2
sx = (width + hbx - 1)/hbx
sy = (height + hby - 1)/hby
win = torch.ones(by, bx, bt)
win1x = torch.cos((torch.arange(-hbx + .5, hbx - .5 + 1)/bx) * np.pi)
win1y = torch.cos((torch.arange(-hby + .5, hby - .5 + 1)/by) * np.pi)
win1t = torch.cos((torch.arange(-hbt + .5, hbt - .5 + 1)/bt) * np.pi)
for x in range(bx):
for y in range(by):
for t in range(bt):
win[y, x, t] = win1y[y]*win1x[x]*win1t[t]
if(bt == 1):
win = torch.squeeze(win)
Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
Pvv = Pvv.double()
bx0 = torch.range(0, bx-1)
by0 = torch.range(0, by-1)
for x in range(0, int((hbx*sx)), int(hbx)):
for y in range(0, int((hby*sy)), int(hby)):
# print(x,y)
#
tx = np.arange(x-hbx+1, x+hbx+1)
validx = np.arange(np.maximum(-tx[0], 0), bx - np.maximum((tx[-1]-width+1), 0))
cx = np.minimum(np.maximum(tx, 0), width-1)
validx = validx.astype(int)
rcx = torch.as_tensor(tx[validx], dtype=torch.long)
bcx = torch.as_tensor(bx0[validx], dtype=torch.long)
ty = np.arange(y-hby+1, y+hby+1)
validy = np.arange(np.maximum(-ty[0], 0), by - np.maximum((ty[-1]-width+1), 0))
cy = np.minimum(np.maximum(ty, 0), width-1)
validy = validy.astype(int)
rcy = torch.as_tensor(ty[validy], dtype=torch.long)
bcy = torch.as_tensor(by0[validy], dtype=torch.long)
cy = torch.as_tensor(cy, dtype=torch.long)
cx = torch.as_tensor(cx, dtype=torch.long)
data_block = torch.index_select(I, 0, cy)
data_block = torch.index_select(data_block, 1, cx)
mean_block = torch.mean(data_block)
win_data_block = (data_block - mean_block)*win
freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
Pss = torch.abs(freq_block)**2
Pss = torch.sum(Pss, 2)
Pss = Pss.double()
H = torch.max((Pss-Pvv), torch.zeros(Pss.size(), dtype=torch.double)) / Pss
H = H.unsqueeze(2).repeat(1, 1, 2)
filt_freq_block = H*freq_block
filt_data_block = torch.irfft(filt_freq_block, win_data_block.ndim, onesided=False)
filt_data_block = (filt_data_block + mean_block*win) * win
# hbt = torch.round(hbt)
filt_data_block = torch.index_select(filt_data_block, 0, bcy)
filt_data_block = torch.index_select(filt_data_block, 1, bcx)
IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] = IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] + filt_data_block
return IR
Apologies if my code is hard to read, I’ve been trying to debug and its made things messy. I would be glad to expand on the above code if needed.