I'm adding a filter to my network but my loss is returning Nan!

I’ve implemented a UNET style architecture for image denoising and it works well. I’m now trying to add a filter, written with Pytorch, inside my network, after the UNET architecture to further denoise the image before returning an output for backpropagation.

My problem is that after adding this filter, my loss (MSE) is returning tensor NaN.

Any ideas as to why this is?

"""Implementation of neural net denoise """

from datetime import datetime

import os

import torch

from torch import nn

# from torch.autograd import Variable

import torchvision

import torchvision.transforms as transforms

# from autoencoder_grayscale import AutoEncoder

import torchvision.datasets as datasets

import numpy as np

import matplotlib.pyplot as plt

# from pyramid_loss import LapLoss

# from skimage.color import rgb2gray

from wiener_3d import wiener_3d

torch.set_printoptions(linewidth=120)

# FUNCTION TO ADD NOISE DATA

def add_noise(image, std, device):

    """ Add noise to images """

    noise = torch.randn(image.size()) * (std)

    noise = noise.to(device)

    noisy_image = image + noise

    return noisy_image

# BATCH NORMALISE IMAGES

def min_max_normalization(tensor, min_value, max_value):

    """Normalizing transformation"""

    min_tensor = tensor.min()

    tensor = (tensor - min_tensor)

    max_tensor = tensor.max()

    tensor = tensor / max_tensor

    tensor = tensor * (max_value - min_value) + min_value

    return tensor

def to_img(image, c, width, height):

    """Reformat tensor for printing """

    image = image.view(image.size(0), 1, width, height)

    return image

# VISUALISE INDIVIDUAL IMAGES

def show_image(image):

    """ plot inidividual images"""

    new_image = image.permute(1, 2, 0)

    plt.imshow(new_image)

    plt.show()

# PLOT ALL IN A BATCH

# just use to_img or view instead...

def plot_batch(images, dimension):

    """ Plot every image in the batch"""

    grid = torchvision.utils.make_grid(images, nrow=6)

    plt.figure(figsize=(dimension, dimension))

    plt.imshow(np.transpose(grid, (1, 2, 0)), norm=None)

# DATALOADER

def load_dataset(size_batch, size):

    """ Get dataset and return dataloader """

    data_path = "test/kodak_validation/"

    transformations = transforms.Compose([

        transforms.Grayscale(num_output_channels=1),

        transforms.CenterCrop(size),

        transforms.Resize(size),

        transforms.ToTensor(),

        # transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),

        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # transforms.Normalize([0], [1])

        ])

    train_dataset = datasets.ImageFolder(

        root=data_path,

        transform=transformations

    )

    train_loader = torch.utils.data.DataLoader(

        train_dataset,

        batch_size=size_batch,

        num_workers=0,

        shuffle=True

    )

    print("loaded")

    return train_loader

class AutoEncoder(nn.Module):

    """Autoencoder simple implementation """

    def __init__(self):

        super(AutoEncoder, self).__init__()

        # Encoder

        # conv layer

        self.block1 = nn.Sequential(

            nn.Conv2d(1, 48, 3, padding=1),

            nn.Conv2d(48, 48, 3, padding=1),

            nn.MaxPool2d(2),

            nn.BatchNorm2d(48),

            nn.LeakyReLU(0.1)

        )

        self.block2 = nn.Sequential(

            nn.Conv2d(48, 48, 3, padding=1),

            nn.MaxPool2d(2),

            nn.BatchNorm2d(48),

            nn.LeakyReLU(0.1)

        )

        self.block3 = nn.Sequential(

            nn.Conv2d(48, 48, 3, padding=1),

            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),

            nn.BatchNorm2d(48),

            nn.LeakyReLU(0.1)

        )

        self.block4 = nn.Sequential(

            nn.Conv2d(96, 96, 3, padding=1),

            nn.Conv2d(96, 96, 3, padding=1),

            nn.ConvTranspose2d(96, 96, 2, 2),

            nn.BatchNorm2d(96),

            nn.LeakyReLU(0.1)

        )

        self.block5 = nn.Sequential(

            nn.Conv2d(144, 96, 3, padding=1),

            nn.Conv2d(96, 96, 3, padding=1),

            nn.ConvTranspose2d(96, 96, 2, 2),

            nn.BatchNorm2d(96),

            nn.LeakyReLU(0.1)

        )

        self.block6 = nn.Sequential(

            nn.Conv2d(97, 64, 3, padding=1),

            nn.BatchNorm2d(64),

            nn.Conv2d(64, 32, 3, padding=1),

            nn.BatchNorm2d(32),

            nn.Conv2d(32, 1, 3, padding=1),

            nn.LeakyReLU(0.1)

        )

    def forward(self, x):

        # print("input: ", x.shape)

        pool1 = self.block1(x)

        # print("pool1: ", pool1.shape)

        pool2 = self.block2(pool1)

        # print("pool2: ", pool2.shape)

        pool3 = self.block2(pool2)

        # print("pool3: ", pool3.shape)

        pool4 = self.block2(pool3)

        # print("pool4: ", pool4.shape)

        pool5 = self.block2(pool4)

        # print("pool5: ", pool5.shape)

        upsample5 = self.block3(pool5)

        # print("upsample5: ", upsample5.shape)

        concat5 = torch.cat((upsample5, pool4), 1)

        # print("concat5: ", concat5.shape)

        upsample4 = self.block4(concat5)

        # print("upsample4: ", upsample4.shape)

        concat4 = torch.cat((upsample4, pool3), 1)

        # print("concat4: ", concat4.shape)

        upsample3 = self.block5(concat4)

        # print("upsample3: ", upsample3.shape)

        concat3 = torch.cat((upsample3, pool2), 1)

        # print("concat3: ", concat3.shape)

        upsample2 = self.block5(concat3)

        # print("upsample2: ", upsample2.shape)

        concat2 = torch.cat((upsample2, pool1), 1)

        # print("concat2: ", concat2.shape)

        upsample1 = self.block5(concat2)

        # print("upsample1: ", upsample1.shape)

        concat1 = torch.cat((upsample1, x), 1)

        # print("concat1: ", concat1.shape)

        output = self.block6(concat1)

        # print("output: ", output.shape)

         # STEP 2: SUBTRACT NETWORK OUTPUT FROM NOISY IMAGE TO GET TEXTURE MAP

        t_map = x - output

        # Convert to size for Wiener filtering 

        for i in range(4):

            tensor = t_map[i, :, :, :]                 # Take each item in batch separately. Could account for this in Wiener instead

            # torchvision.utils.save_image(tensor[:, :, :], 'x{}subtract.png'.format(i))

            tensor = torch.squeeze(tensor)              # Squeeze for Wiener input format

            # STEP 3: APPLY WIENER TO TEXTURE MAP

            tensor = wiener_3d(tensor, 0.05, 10)        # Apply Wiener with specified std and block size

            tensor = torch.unsqueeze(tensor, 0)         # unsqueeze to put back into block

            # torchvision.utils.save_image(tensor[:, :, :], 'x{}wiener_tmap.png'.format(i))

            t_map[i, :, :, :] = tensor                  # put back into block

        # STEP 4: ADD FILTERED TEXTURE MAP BACK ONTO NET OUTPUT TO RESTORE DETAIL

        filtered_output = output + t_map

        return filtered_output

def train_gray(epoch):

    train_loss = 0.0

    for data in data_loader:

        img, _ = data

        img = img.to(device)

        noisy_img = add_noise(img, 0.05, device)

        # forward pass

        # STEP 1: APPLY netowrk TO NOISY IMAGE

        output = model(noisy_img)

        if (epoch % 5 == 0):

            torchvision.utils.save_image(filtered_output, 'x_out{}.png'.format(epoch))

        loss = criterion(output, img)

        # backwards

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        train_loss += loss.item()*img.size(0)

    train_loss = train_loss/len(data_loader)

    print('Epoch: {} \tTraining Loss: {:.6f}'.format(

        epoch,

        train_loss

        ))

    return train_loss

def checkpoint(epoch, train_loss):

    torch.save({

            'epoch': epoch,

            'model_state_dict': model.state_dict(),

            'optimizer_state_dict': optimizer.state_dict(),

            'train_loss': train_loss

            }, path+"/model_epoch_{}.pt".format(epoch))

    print("Epoch saved")

now = datetime.now()

current_time = now.strftime("%H_%M_%S")

path = "test/test_training_gray/{}".format(current_time)

os.mkdir(path)

width = 112

height = 112

num_epochs = 100

batch_size = 4

learning_rate = 0.0001

data_loader = load_dataset(batch_size, width)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

model = AutoEncoder().to(device)

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(

    model.parameters(), lr=learning_rate, weight_decay=1e-5)

############################################################################################

# UNCOMMENT CODE BELOW FOR RESUMING FROM A MODEL

# model = TheModelClass(*args, **kwargs)

# optimizer = TheOptimizerClass(*args, **kwargs)

# model_path = "test/test_training_gray/22_24_55/model_epoch_100.pt"

# save_point = torch.load(model_path)

# model.load_state_dict(save_point['model_state_dict'])

# optimizer.load_state_dict(save_point['optimizer_state_dict'])

# epoch = save_point['epoch']

# #If we wish to do remaining epochs, num_epochs = num_epochs-epoch

# train_loss = save_point['train_loss']

# model.train()

############################################################################################

for i in range(1, num_epochs+1):

    train_loss = train_gray(i)

    checkpoint(i, train_loss)

print("end")

Below is the filter I am implementing:

import numpy as np
import torch


def add_noise(img, std):
    noise = torch.randn(img.size()) * (std)
    noisy_image = img + noise
    return noisy_image


def wiener_3d(I, noise_std, block_size):

    width = I.shape[1]
    height = I.shape[0]
    IR = torch.zeros(height, width, dtype=torch.float64)
    # if(len(list(I.shape)) >= 3):
    #    frames = I.shape[2]
    # else:
    #   bt = 1
    bt = 1
    bx = block_size
    by = block_size

    hbx = bx/2
    hby = by/2
    hbt = bt/2

    sx = (width + hbx - 1)/hbx
    sy = (height + hby - 1)/hby

    win = torch.ones(by, bx, bt)
    win1x = torch.cos((torch.arange(-hbx + .5, hbx - .5 + 1)/bx) * np.pi)
    win1y = torch.cos((torch.arange(-hby + .5, hby - .5 + 1)/by) * np.pi)
    win1t = torch.cos((torch.arange(-hbt + .5, hbt - .5 + 1)/bt) * np.pi)

    for x in range(bx):
        for y in range(by):
            for t in range(bt):
                win[y, x, t] = win1y[y]*win1x[x]*win1t[t]

    if(bt == 1):
        win = torch.squeeze(win)

    Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
    Pvv = Pvv.double()
    bx0 = torch.range(0, bx-1)
    by0 = torch.range(0, by-1)

    for x in range(0, int((hbx*sx)), int(hbx)):
        for y in range(0, int((hby*sy)), int(hby)):
            # print(x,y)
            #
            tx = np.arange(x-hbx+1, x+hbx+1)
            validx = np.arange(np.maximum(-tx[0], 0), bx - np.maximum((tx[-1]-width+1), 0))
            cx = np.minimum(np.maximum(tx, 0), width-1)
            validx = validx.astype(int)
            rcx = torch.as_tensor(tx[validx], dtype=torch.long)
            bcx = torch.as_tensor(bx0[validx], dtype=torch.long)

            ty = np.arange(y-hby+1, y+hby+1)
            validy = np.arange(np.maximum(-ty[0], 0), by - np.maximum((ty[-1]-width+1), 0))
            cy = np.minimum(np.maximum(ty, 0), width-1)
            validy = validy.astype(int)
            rcy = torch.as_tensor(ty[validy], dtype=torch.long)
            bcy = torch.as_tensor(by0[validy], dtype=torch.long)

            cy = torch.as_tensor(cy, dtype=torch.long)
            cx = torch.as_tensor(cx, dtype=torch.long)
            data_block = torch.index_select(I, 0, cy)
            data_block = torch.index_select(data_block, 1, cx)

            mean_block = torch.mean(data_block)
            win_data_block = (data_block - mean_block)*win

            freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
            Pss = torch.abs(freq_block)**2
            Pss = torch.sum(Pss, 2)
            Pss = Pss.double()
            H = torch.max((Pss-Pvv), torch.zeros(Pss.size(), dtype=torch.double)) / Pss

            H = H.unsqueeze(2).repeat(1, 1, 2)

            filt_freq_block = H*freq_block
            filt_data_block = torch.irfft(filt_freq_block, win_data_block.ndim, onesided=False)
            filt_data_block = (filt_data_block + mean_block*win) * win
            # hbt = torch.round(hbt)

            filt_data_block = torch.index_select(filt_data_block, 0, bcy)
            filt_data_block = torch.index_select(filt_data_block, 1, bcx)
            IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] = IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] + filt_data_block
    return IR

Apologies if my code is hard to read, I’ve been trying to debug and its made things messy. I would be glad to expand on the above code if needed.

Update: I’ve found that my filter is introducing NaN (Wiener being called after the UNet). Im trying to narrow down where that is occurring within the filter now.

i also get nan with rfft. Any clue?

Apologies for not coming back to this thread. If I recall correctly, in the end, the problem was not with rfft() but with my multiplication operation.

I analyzed my tensors and found that after this operation, NaNs were present. This was because I was trying to multiply by zero. To get around this, I added a small “eps”/epsilon to the data before multiplying it, i.e. eps = 1e-15.

I have also since moved onto the most recent version of PyTorch which deprecated the rfft function for fft.fftn().

Thanks, the eps trick works. But unfortunately I am still stuck at PyTorch 1.4.0 and fftn is not available only rfft.