CNN and noise filtering

Hello guys, hope you are all alright.

I’m facing a problem. I’m trying to use a ResNet18 to filter noise out of images. For that, I’m using the MNIST dataset and injecting noise into it. The code runs without any error message, but here comes the problem:

The code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image

# class NoisyMNIST
class NoisyMNIST(MNIST):
    def __init__(self, root, train=True, transform=None, noisy_transform=None, download=False):
        super().__init__(root, train=train, transform=transform, download=download)
        self.noisy_transform = noisy_transform

    def __getitem__(self, index):
        img = self.data[index].view(28, 28)  # raw uint8 image; ToPILImage/ToTensor in the transforms convert it

        if self.transform is not None:
            clean_img = self.transform(img)

        if self.noisy_transform is not None:
            noisy_img = self.noisy_transform(img)

        return noisy_img, clean_img


#pre-trained resnet
resnet = models.resnet18(pretrained=True)


resnet = nn.Sequential(*list(resnet.children())[:-1])  # drops only the final fc layer; avgpool is kept, so the output is N x 512 x 1 x 1


resnet.add_module('upsample', nn.Upsample(size=(28, 28), mode='bilinear', align_corners=False))
resnet.add_module('leakyrelu', nn.LeakyReLU())
resnet.add_module('dropout', nn.Dropout(p=0.5))
resnet.add_module('last_conv', nn.Conv2d(512, 3, kernel_size=1))
torch.nn.init.kaiming_normal_(resnet.last_conv.weight)

# loss and optimizer
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)

# Load dataset and add gaussian noise to the images
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.float()),  # Convert to FloatTensor
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # repeating channel
])

noisy_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.float()),  # Convert to FloatTensor
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # repeating channel
    transforms.Lambda(lambda x: x + torch.randn_like(x))  # adding noise
])


train_set = NoisyMNIST('./data', train=True, transform=transform, noisy_transform=noisy_transform, download=True)
test_set = NoisyMNIST('./data', train=False, transform=transform, noisy_transform=noisy_transform, download=True)

# Split MNIST dataset
# train_set, valid_set = torch.utils.data.random_split(train_set, [50000, 10000])

# Reducing it to 10%
num_train = int(len(train_set) * 0.1)
num_valid = int(len(test_set) * 0.05)  # Use half (5%) for validation and half (5%) for testing
num_test = num_valid

train_set, _ = random_split(train_set, [num_train, len(train_set) - num_train])
# Split the test set once so the validation and test subsets don't overlap
valid_set, test_set, _ = random_split(test_set, [num_valid, num_test, len(test_set) - num_valid - num_test])

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

# To train the model
num_epochs = 10
train_losses = []
valid_losses = []

# Sanity check: print the shapes of a single batch
inputs, targets = next(iter(train_loader))
print(f"inputs shape: {inputs.shape}")
print(f"targets shape: {targets.shape}")

for epoch in range(num_epochs):
    print(f"Training epoch {epoch+1}/{num_epochs}...")
    running_loss = 0.0

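    # Training
    resnet.train()  # enable dropout and batch-norm statistics updates
    for i, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs = inputs.float()
        targets = targets.float()

        optimizer.zero_grad()
        outputs = resnet(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validating
    resnet.eval()  # disable dropout and use the batch-norm running statistics
    valid_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(valid_loader, 0):
            inputs, targets = data
            inputs = inputs.float()
            targets = targets.float()

            outputs = resnet(inputs)
            loss = criterion(outputs, targets)

            valid_loss += loss.item()

    # Store one averaged value per epoch for both curves so they can be plotted together
    train_losses.append(running_loss/len(train_loader))
    valid_losses.append(valid_loss/len(valid_loader))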

print("Training completed!")

# Evaluate model
print("Evaluating model...")
resnet.eval()  # make sure dropout is disabled during evaluation
test_losses = []
with torch.no_grad():
    for i, data in enumerate(test_loader, 0):
        inputs, targets = data
        inputs = inputs.float()
        targets = targets.float()

        outputs = resnet(inputs)
        loss = criterion(outputs, targets)

        test_losses.append(loss.item())

# Plot loss
plt.plot(train_losses, label='Train')
plt.plot(valid_losses, label='Validation')
plt.legend()
plt.title('Loss')
plt.show()

# Save the trained model
#torch.save(resnet.state_dict(), './resnet_noise_removal.pth')

# Show the input images and the denoised images
plt.figure(figsize=(10, 5))
for i in range(5):
    plt.subplot(2, 5, i+1)
    plt.imshow(inputs[i].permute(1, 2, 0).numpy(), cmap='gray')
    plt.title('Input')
    plt.axis('off')

    plt.subplot(2, 5, i+6)
    plt.imshow(outputs[i].detach().permute(1, 2, 0).numpy(), cmap='gray')
    plt.title('Denoised')
    plt.axis('off')
plt.show()


It’s not returning the denoised image; the output is all black.

What am I doing wrong? What can I do to fix this?

Thanks!

Try to overfit a tiny subset of your data (e.g. just 10 samples) by tuning the hyperparameters of your training. Once your model is able to overfit this subset, you could try to scale up the use case again.
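A minimal sketch of that overfitting check, assuming synthetic stand-in data (10 random "clean" images plus clipped Gaussian noise) instead of real MNIST samples and a hypothetical small fully connected model; the helper name overfit_tiny_subset is illustrative only. If the loss does not drop substantially on such a tiny fixed batch, the problem most likely lies in the model or the training setup rather than in the amount of data.

```python
import torch
import torch.nn as nn

def overfit_tiny_subset(model, noisy, clean, steps=500, lr=1e-3):
    """Hypothetical helper: train on one fixed tiny batch and return the loss history."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(noisy), clean)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
# Synthetic stand-in for 10 MNIST digits: values in [0, 1], shape (N, 1, 28, 28).
clean = torch.rand(10, 1, 28, 28)
noisy = (clean + 0.3 * torch.randn_like(clean)).clamp(0.0, 1.0)

# Hypothetical small model; any model mapping (N, 1, 28, 28) -> (N, 1, 28, 28) works here.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 28 * 28), nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
)

losses = overfit_tiny_subset(model, noisy, clean)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

With real data, clean and noisy would be the first batch from the train loader and model would be the network under test.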

Here are some things I would try (can’t guarantee that any of them will work though).

  1. Drop the normalization of (0.5,), (0.5,) since that’s not how the data was normalized when the pre-trained model was trained. Just try without any normalization for now. Additionally, the model was trained on colour images, so the input distribution is probably completely different.
  2. Don’t use a pre-trained model. For this specific use-case of MNIST images, you probably don’t need a pre-trained model - and it’s probably hurting you.
  3. Check the range of output values. Ideally, you want them to be in the range [0.0 … 1.0]. If they aren’t (which is likely), you can force it by applying a sigmoid() to the output tensor.
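To illustrate points 1 to 3 together, here is a hedged sketch of a small convolutional encoder-decoder trained from scratch. The class name SmallDenoiser and the layer sizes are hypothetical choices, not taken from the thread. It takes single-channel 28 x 28 images without any normalization, keeps a 7 x 7 spatial map at the bottleneck so the digit shape can be reconstructed, and ends with a sigmoid so the outputs lie in [0.0, 1.0].

```python
import torch
import torch.nn as nn

class SmallDenoiser(nn.Module):
    """Hypothetical from-scratch encoder-decoder for 1x28x28 images."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),                                   # 28 -> 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),                                   # 14 -> 7
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2), nn.ReLU(),  # 7 -> 14
            nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),              # 14 -> 28
            nn.Sigmoid(),                                      # squash outputs into [0, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

x = torch.rand(4, 1, 28, 28)  # a batch of unnormalized images in [0, 1]
out = SmallDenoiser()(x)
print(out.shape)  # torch.Size([4, 1, 28, 28])
```

Because the output already lies in [0, 1], it can be compared directly with unnormalized clean targets using MSELoss or L1Loss.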

HTH and all the best!

Thanks, @ptrblck and @dhruvbird, for your kind suggestions…

I tried all of them, but I’m still facing the same problem. Here’s what I’ve tried:

  1. I changed to a model that was not pre-trained: resnet = models.resnet34(pretrained=False)
  2. Switched the device to a GPU
  3. Dropped the normalization of (0.5,), (0.5,)
  4. I am applying sigmoid() on the output tensor
  5. I am using now only one channel on the images
  6. Increased num_epochs to 60
  7. I tried the training with 2 different loss functions: L1Loss() and BCELoss()
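For item 2, one detail that is easy to miss when moving to a GPU: every batch has to be moved to the same device as the model, and results have to come back to the CPU before calling .numpy() for plotting, since a CUDA tensor cannot be converted to NumPy directly. A minimal device-agnostic sketch, using a hypothetical one-layer stand-in model:

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the denoising model; move its parameters to the device once.
model = nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)

inputs = torch.rand(4, 1, 28, 28)  # batches from a DataLoader start on the CPU
with torch.no_grad():
    outputs = model(inputs.to(device))  # move each batch to the model's device

# Bring results back to the CPU before converting to NumPy for matplotlib.
images = outputs.detach().cpu().numpy()
print(images.shape)  # (4, 1, 28, 28)
```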

With BCELoss(), working on 1 channel, 60 epochs of training:

[two images]

With L1Loss(), working on 1 channel, 60 epochs of training:

[two images]

It’s quite intriguing…

Regarding transforms.Lambda(lambda x: x + torch.randn_like(x)): I’m not sure this is entirely correct. It adds random noise to your image, but it also pushes the values outside the [0.0 … 1.0] range. Would you consider adding noise with mean 0.5 and standard deviation 0.5 and then clipping the values to [0.0 … 1.0] if they fall outside that range?

It’s not clear to me how one would use BCELoss() for this problem. Please could you try MSE loss instead?

Hey, @dhruvbird, take a look. I added this function:

def add_gaussian_noise(image):
    noise = torch.randn(image.size()) * 0.5 + 0.5
    noisy_image = image + noise
    noisy_image = torch.clamp(noisy_image, 0., 1.)
    return noisy_image

on:

noisy_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.float()),  # Convert to FloatTensor
    transforms.Lambda(add_gaussian_noise)  # add noise
])

and changed the loss function to MSELoss(). Here are the outputs:

[two images]