Autoencoder model testing and modifications for hallucination

Hey! I am currently implementing an autoencoder with a simple MNIST dataset. My goals are as follows:

  1. Autoencode the images using the network. 28x28 input and 28x28 output
  2. Test the Model with an input image, and see what the output would be
  3. Finally, I want the network to receive a broken image as its TRAINING input. I break each image manually by covering its middle with a grey square. So basically, the network has never seen the MNIST images before corruption, but it should try to hallucinate what the image looked like before it was corrupted.

Of course point 3 is very specific, but the first issue I’m currently facing is testing the model itself.

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable

def to_img(x):
    """Reshape a batch of flattened images into ``(N, 1, 28, 28)`` image tensors.

    The batch size is taken from the first dimension of ``x``; the remaining
    elements of each row must therefore amount to exactly 28 * 28 values.
    """
    batch_size = x.shape[0]
    return x.view(batch_size, 1, 28, 28)

def plot_sample_img(img, name):
    """Save a single 28x28 image tensor to ``./sample_<name>.png``.

    Args:
        img: Tensor holding exactly 28 * 28 values (e.g. a flattened 784-vector
            or a ``(1, 28, 28)`` image).
        name: Value interpolated into the output file name.
    """
    img = img.view(1, 28, 28)
    # `save_image` was never imported at file level; go through the already
    # imported `torchvision` package so the call does not raise NameError.
    torchvision.utils.save_image(img, './sample_{}.png'.format(name))

class AE(nn.Module):
    """Fully connected autoencoder with a 128-unit code.

    Four linear layers (input -> 128 -> 128 -> 128 -> input), each followed by
    a ReLU, including the final reconstruction layer.

    Keyword Args:
        input_shape: Number of features per flattened input sample. Required;
            a missing key raises ``KeyError``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        n_features = kwargs["input_shape"]
        # Layers are created in encoder-to-decoder order; the attribute names
        # are the keys that end up in the state dict.
        self.encoder_hidden_layer = nn.Linear(n_features, 128)
        self.encoder_output_layer = nn.Linear(128, 128)
        self.decoder_hidden_layer = nn.Linear(128, 128)
        self.decoder_output_layer = nn.Linear(128, n_features)

    def forward(self, features):
        """Encode ``features`` to the 128-d code and decode them back."""
        pipeline = (
            self.encoder_hidden_layer,
            self.encoder_output_layer,
            self.decoder_hidden_layer,
            self.decoder_output_layer,
        )
        signal = features
        for layer in pipeline:
            # Every stage is "linear then ReLU", the output stage included.
            signal = torch.relu(layer(signal))
        return signal

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the autoencoder for flattened 28x28 inputs (784 features) and move its
# parameters onto the selected device.
model = AE(input_shape=784)
model = model.to(device)

# Adam over all model parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Reconstruction objective: mean-squared error between output and target.
criterion = nn.MSELoss()

def erase_middle(
    image: torch.Tensor, *, half_size: int = 5, fill: float = 0.5
) -> torch.Tensor:
    """Overwrite a centred square patch of ``image`` with a constant value.

    The patch spans ``half_size`` pixels on either side of the image centre
    along both of the last two dimensions, clipped to the image bounds.

    Args:
        image: Tensor whose last two dimensions are ``(height, width)``, e.g.
            ``(C, H, W)``, ``(N, C, H, W)`` or plain ``(H, W)``.
        half_size: Half of the patch side length in pixels (default 5, i.e. a
            10x10 patch).
        fill: Value written into the patch (default 0.5, mid-grey for images
            scaled to ``[0, 1]``).

    Returns:
        The same tensor object that was passed in; it is modified in place.

    Raises:
        ValueError: If ``image`` has fewer than two dimensions.
    """
    if image.dim() < 2:
        raise ValueError(
            "erase_middle expects a tensor with at least 2 dimensions, "
            "got {}".format(image.dim())
        )
    height, width = image.shape[-2:]
    # Clamp to the image bounds: a negative start would otherwise be read as
    # "count from the end" and mask the wrong region on small images.
    y_start = max(height // 2 - half_size, 0)
    y_end = min(height // 2 + half_size, height)
    x_start = max(width // 2 - half_size, 0)
    x_end = min(width // 2 + half_size, width)
    # Ellipsis indexing applies the patch to every leading (batch/channel) dim.
    image[..., y_start:y_end, x_start:x_end] = fill
    return image

# Per-sample preprocessing: convert the image to a tensor, then grey out its
# centre with `erase_middle`. The same pipeline is used for both splits.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(erase_middle),
    ]
)

# MNIST training and test splits, downloaded on first use.
train_dataset = torchvision.datasets.MNIST(
    "~/torch_datasets",
    train=True,
    transform=transform,
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    "~/torch_datasets",
    train=False,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 128 for training; ordered batches of 32 for testing.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

def imshow(img):
    """Draw a ``(C, H, W)`` image tensor on the current matplotlib axes.

    Args:
        img: Image tensor in channel-first layout, e.g. the output of
            ``torchvision.utils.make_grid``. It may live on any device and may
            be attached to an autograd graph.
    """
    # `Tensor.numpy()` refuses CUDA tensors and tensors that require grad, so
    # detach and move to host memory first (both are no-ops when not needed).
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channel-last (H, W, C) data.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))



# Preview one training batch as an image grid.
dataiter = iter(train_loader)
# DataLoader iterators follow the standard iterator protocol; the old
# Python-2-style `.next()` method is not available on current PyTorch releases.
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))

# Train the autoencoder for a fixed number of epochs.
# NOTE: the reconstruction target is the input batch itself, i.e. the images
# as produced by the dataset transform (centre already greyed out).
epochs = 10
for epoch in range(epochs):
    # running sum of mini-batch losses for this epoch
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to [N, 784] matrix
        # load it to the active device
        batch_features = batch_features.view(-1, 784).to(device)

        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()

        # compute reconstructions
        outputs = model(batch_features)

        # compute training reconstruction loss (output vs. the same input batch)
        train_loss = criterion(outputs, batch_features)

        # compute accumulated gradients
        train_loss.backward()

        # perform parameter update based on current gradients
        optimizer.step()

        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()

        # saving per iter to allow for resume
        # The checkpoint is rewritten to the same path after every mini-batch.
        # The stored 'loss' is the running sum so far in this epoch, not a mean.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, "/content/model_saved")

    # compute the epoch training loss (mean over the number of mini-batches)
    loss = loss / len(train_loader)

    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

# Show one test batch next to its reconstruction.
model.eval()

# Inference only: skip building the autograd graph. Tensors can be fed to the
# model directly; wrapping them in `Variable` is no longer needed.
with torch.no_grad():
    batch_features, _ = next(iter(test_loader))
    batch_features = batch_features.view(-1, 784).to(device)
    pred = model(batch_features)

# Restore the (N, 1, 28, 28) image layout that `make_grid` expects and bring
# everything back to the CPU for plotting. The final ReLU is unbounded above,
# so clamp reconstructions into the displayable [0, 1] range.
input_grid = torchvision.utils.make_grid(to_img(batch_features).cpu())
output_grid = torchvision.utils.make_grid(to_img(pred).clamp(0, 1).cpu())

# One figure, two panels: network input on the left, reconstruction on the right.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("input")
plt.axis("off")
imshow(input_grid)
plt.subplot(1, 2, 2)
plt.title("reconstruction")
plt.axis("off")
imshow(output_grid)
plt.show()

The last few lines were my attempt at plotting the images. Could someone guide me on how to plot the images from this? Ideally it would show the input and output images next to each other.

And how would I use the images with the grey square in the middle to make the network “predict” what was there before? As far as I know, it is really hard to predict something the neural net has never seen, since it could not have learned about it beforehand.

Any help would be appreciated!

Are you currently stuck at a specific point in your code?
E.g. does imshow(make_grid(...)) not work or would you like to change the layout somehow?