Hello guys, hope you are all alright.
I’m facing a problem here. I’m trying to implement a ResNet18 to be able to filter noise. For that, I’m using the MNIST dataset and injecting noise on it. The code runs fine, no error message is shown, but here comes the problem:
the code:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision import models
import matplotlib.pyplot as plt
from PIL import Image
# class NoisyMNIST
class NoisyMNIST(MNIST):
def __init__(self, root, train=True, transform=None, noisy_transform=None, download=False):
super().__init__(root, train=train, transform=transform, download=download)
self.noisy_transform = noisy_transform
def __getitem__(self, index):
img = self.data[index].view(28, 28) # Remove the normalization and conversion to tensor
if self.transform is not None:
clean_img = self.transform(img)
if self.noisy_transform is not None:
noisy_img = self.noisy_transform(img)
return noisy_img, clean_img
#pre-trained resnet
resnet = models.resnet18(pretrained=True)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet.add_module('upsample', nn.Upsample(size=(28, 28), mode='bilinear', align_corners=False))
resnet.add_module('leakyrelu', nn.LeakyReLU())
resnet.add_module('dropout', nn.Dropout(p=0.5))
resnet.add_module('last_conv', nn.Conv2d(512, 3, kernel_size=1))
torch.nn.init.kaiming_normal_(resnet.last_conv.weight)
# loss and optimizer
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)
# Load dataset and add gaussian noise to the images
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.float()), # Convert to FloatTensor
transforms.Normalize((0.5,), (0.5,)),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)) # repeating channel
])
noisy_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Lambda(lambda x: x.float()), # Convert to FloatTensor
transforms.Normalize((0.5,), (0.5,)),
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), # repeating channel
transforms.Lambda(lambda x: x + torch.randn_like(x)) # adding noise
])
train_set = NoisyMNIST('./data', train=True, transform=transform, noisy_transform=noisy_transform, download=True)
test_set = NoisyMNIST('./data', train=False, transform=transform, noisy_transform=noisy_transform, download=True)
# Split MNIST dataset
# train_set, valid_set = torch.utils.data.random_split(train_set, [50000, 10000])
# Reducing it to 10%
num_train = int(len(train_set) * 0.1)
num_valid = int(len(test_set) * 0.05) # Usamos metade para validação e metade para teste
num_test = num_valid
train_set, _ = random_split(train_set, [num_train, len(train_set) - num_train])
valid_set, _ = random_split(test_set, [num_valid, len(test_set) - num_valid])
test_set, _ = random_split(test_set, [num_test, len(test_set) - num_test])
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
# To train the model
num_epochs = 10
train_losses = []
valid_losses = []
for i, data in enumerate(train_loader, 0):
inputs, targets = data
inputs = inputs.float()
targets = targets.float()
print(f"inputs shape: {inputs.shape}")
print(f"targets shape: {targets.shape}")
for epoch in range(num_epochs):
print(f"Training epoch {epoch+1}/{num_epochs}...")
running_loss = 0.0
# Training
for i, data in enumerate(train_loader, 0):
inputs, targets = data
inputs = inputs.float()
targets = targets.float()
optimizer.zero_grad()
outputs = resnet(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
# Validating
with torch.no_grad():
for i, data in enumerate(valid_loader, 0):
inputs, targets = data
inputs = inputs.float()
targets = targets.float()
outputs = resnet(inputs)
loss = criterion(outputs, targets)
valid_losses.append(loss.item())
train_losses.append(running_loss/len(train_loader))
print("Training completed!")
# Evaluate model
print("Evaluating model...")
test_losses = []
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, targets = data
inputs = inputs.float()
targets = targets.float()
outputs = resnet(inputs)
loss = criterion(outputs, targets)
test_losses.append(loss.item())
# Plot loss
plt.plot(train_losses, label='Treino')
plt.plot(valid_losses, label='Validação')
plt.legend()
plt.title('Perda')
plt.show()
# Save the trained model
#torch.save(resnet.state_dict(), './resnet_noise_removal.pth')
# Show the input images and the denoised images
plt.figure(figsize=(10, 5))
for i in range(5):
plt.subplot(2, 5, i+1)
plt.imshow(inputs[i].permute(1, 2, 0).numpy(), cmap='gray')
plt.title('Input')
plt.axis('off')
plt.subplot(2, 5, i+6)
plt.imshow(outputs[i].detach().permute(1, 2, 0).numpy(), cmap='gray')
plt.title('Denoised')
plt.axis('off')
plt.show()
It’s not returning the image without noise, it’s returning all black.
What am I doing wrong? What can I do to fix this?
Thanks!