Why is my Auto-Encoder not learning the distribution of FMNIST images?

I am using a simple autoencoder to learn to reconstruct images from the FashionMNIST dataset. I preprocessed the dataset by converting it to grayscale and normalizing it to [0, 1]. I kept the network shallow so that it would not simply learn a trivial identity mapping.

Here’s my PyTorch code:

import os

import matplotlib.pyplot as plt
import torch
import torchvision as tv
from torch import nn

# FashionMNIST is already single-channel, so this Grayscale transform is effectively a no-op
transform = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1)])
trainset = tv.datasets.FashionMNIST(root='./data', train=True,
                                    download=True, transform=transform)
PATH = './ae.pth'

# Accessing trainset.data directly bypasses the transform pipeline,
# so the raw uint8 images are converted and scaled to [0, 1] by hand
data = trainset.data.float()
data = data / 255

plt.imshow(trainset.data[0], cmap='gray')
plt.show()



class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        # Encoder: 784 -> 512 -> 30-dimensional bottleneck
        self.encode = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 30),
            nn.ReLU()
        )
        # Decoder: 30 -> 512 -> 784; Sigmoid keeps outputs in [0, 1] for BCELoss
        self.decode = nn.Sequential(
            nn.Linear(30, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.flatten(x)
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded

if os.path.exists(PATH):
    # Reuse a previously trained model if one was saved
    print("Loading saved model on CPU")
    device = torch.device('cpu')
    model = NeuralNetwork()
    model.load_state_dict(torch.load(PATH, map_location=device))

else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = data.to(device)
    print(f"Using device = {device}")
    model = NeuralNetwork().to(device)

    # Pixel values lie in [0, 1], so binary cross-entropy works as a reconstruction loss
    lossFn = nn.BCELoss()

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Full-batch training: each epoch is a single gradient step over all 60,000 images
    for epoch in range(1000):
        print("Epoch = ", epoch)
        optimizer.zero_grad()
        outputs = model(data)
        loss = lossFn(outputs, data.reshape(-1, 784))
        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), PATH)
    data = data.to("cpu")
    model = model.to("cpu")

pred = model(data)
pred = pred.reshape(-1, 28, 28)
plt.imshow(pred.detach().numpy()[0], cmap='gray')
plt.show()

For testing, this is the input image:

And this is the image the model outputs:

Hi.
If your model can’t re-create its input images, that means it isn’t learning to extract features from the images or to decode them back.
I suggest using convolutional and transposed-convolutional (deconvolution) layers instead of fully connected ones.
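
As a rough sketch of what that could look like for 28×28 FashionMNIST images (the channel counts, kernel sizes, and class name below are my own assumptions, not something from your code):

import torch
from torch import nn

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: (N, 1, 28, 28) -> (N, 16, 14, 14) -> (N, 32, 7, 7)
        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        # Decoder: (N, 32, 7, 7) -> (N, 16, 14, 14) -> (N, 1, 28, 28)
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.Sigmoid()  # keeps outputs in [0, 1], matching BCELoss
        )

    def forward(self, x):
        return self.decode(self.encode(x))

With a model like this you would feed the images as (N, 1, 28, 28) tensors (e.g. data.unsqueeze(1)) and compare the output against a target of the same shape, instead of flattening everything to 784.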
