CNN VAE not working while feed-forward VAE works fine

I am trying to implement a VAE with convolutions, but for some reason it is not working, while with a feed-forward network it works. I do not know what difference is causing the error. Please help, thank you.

One thing to note is that the CNN-VAE loss never drops below 140 and seems to converge too early, or at least that is what I'm seeing.
The FC VAE loss reaches around 100 and performs well.
For some reason the CNN-VAE's KLD sits up around 40, while for the FC VAE it is around 24.

PyTorch CNN VAE model


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
#%matplotlib inline
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from IPython.display import Image

batch_size = 100 #- 500
epochs = 20
learning_rate = 1e-3

if torch.cuda.is_available():  
  dev = "cuda:0" 
else:  
  dev = "cpu" 

device = torch.device(dev)

Train = True

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3, stride=2, padding=1),
        nn.ReLU()
        )

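        # convolutional heads producing mu and logvar, each with output shape (N, 16, 5, 5)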
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=16, kernel_size=3)


        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=16, out_channels=12, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=6, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )


    def passToEnoch(self, image):

        t = image

        t = self.encoder(t)
        mu = self.conv3(t)
        logvar = self.conv4(t)

        return mu, logvar

    def passToDenoch(self, BN):

        t = BN

        t = self.decoder(t)
        GenImg = t

        return GenImg
    
    def Reparameterise(self, mean, logvar):
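        # reparameterisation trick: z = mean + eps * std, with eps ~ N(0, I)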
       
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mean + eps*std

    def forward(self, image):

        mu, logvar = self.passToEnoch(image)

        z = self.Reparameterise(mu, logvar)

        construction = self.passToDenoch(z)
        
        return construction, mu, logvar

train_set = torchvision.datasets.MNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
'''

train_loader = torch.utils.data.DataLoader(train_set, batch_size=len(train_set), shuffle=True)
data, label = next(iter(train_loader))
mean = data.mean()
std = data.std()
'''


train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)

network = Model()
#network = network.to(torch.device("cuda:0"))

optimizer = optim.Adam(network.parameters(), lr=0.0001)


def loss_function(pred, images, mu, logvar):

    criterion = nn.BCELoss(reduction = 'sum')
    reconstructionLoss = criterion(pred, images)
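    # closed-form KL divergence between the posterior N(mu, sigma^2) and the prior N(0, I):
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)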
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstructionLoss ,KLD



for epoch in range(20):

    total_loss_1 = 0
    reconstruction_Loss = 0
    KLD_loss_2 = 0


    for batch in train_loader: # Get Batch
        
        images, _ = batch
        
        optimizer.zero_grad()
        construction, mu, logvar = network(images) # Pass Batch
        reconstructionLoss, KLD = loss_function(construction, images, mu, logvar) # Calculate Loss
        loss  = reconstructionLoss + KLD
        
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss_1 += loss.item()
        reconstruction_Loss += reconstructionLoss.item()
        KLD_loss_2 += KLD.item()

    print("epoch", epoch, "total_loss_1:", total_loss_1/60000, "reconstruction_loss: ", reconstruction_Loss/60000,"KLD: ", KLD_loss_2/60000)

torch.save(network.state_dict(), 'VAE.pth')


with torch.no_grad():
    network.eval()
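    # sample latents matching the encoder output shape: 28x28 -> 14x14 -> 7x7 -> 5x5, 16 channels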

    o = np.random.normal(0,1, (100,16,5,5))
    o = torch.from_numpy(o).float()
    Train = False
    #d = torch.distributions.Normal(0,1)
    #o = d.sample((10,64,1,1))  
    #o = o.reshape(-1,64)
    #o = o.to(torch.device("cuda:0"))
    reconstruction = network.passToDenoch(o)

save_image(reconstruction,'Final.png')
Image('Final.png') 

epoch 0 total_loss_1: 437.0815269856771 reconstruction_loss:  424.89729869791665 KLD:  12.184227733802796
epoch 1 total_loss_1: 271.06046875 reconstruction_loss:  235.04757037760416 KLD:  36.012898413085935
epoch 2 total_loss_1: 237.2292015625 reconstruction_loss:  201.87340013020832 KLD:  35.355801139322914
epoch 3 total_loss_1: 225.27503678385418 reconstruction_loss:  189.86380341796874 KLD:  35.411233211263024
epoch 4 total_loss_1: 218.279416015625 reconstruction_loss:  182.6045441731771 KLD:  35.674871655273435
epoch 5 total_loss_1: 212.743983203125 reconstruction_loss:  176.58763538411458 KLD:  36.156347798665365
epoch 6 total_loss_1: 208.0514568033854 reconstruction_loss:  171.18413247070313 KLD:  36.867324344889326
epoch 7 total_loss_1: 203.88413160807292 reconstruction_loss:  166.22725283203124 KLD:  37.65687854817708
epoch 8 total_loss_1: 200.00452884114583 reconstruction_loss:  161.81875375976563 KLD:  38.18577508951823
epoch 9 total_loss_1: 196.12084052734374 reconstruction_loss:  157.47882067057293 KLD:  38.642020076497396
epoch 10 total_loss_1: 191.95198984375 reconstruction_loss:  153.00707041015625 KLD:  38.94491872151693
epoch 11 total_loss_1: 187.28410458984376 reconstruction_loss:  148.25767250976563 KLD:  39.02643206787109
epoch 12 total_loss_1: 181.97289752604166 reconstruction_loss:  143.05000634765625 KLD:  38.92289149169922
epoch 13 total_loss_1: 176.001284375 reconstruction_loss:  137.28636323242188 KLD:  38.714921118164064
epoch 14 total_loss_1: 170.22259750976562 reconstruction_loss:  131.3868727701823 KLD:  38.83572508951823
epoch 15 total_loss_1: 164.65682986653647 reconstruction_loss:  125.59177838541666 KLD:  39.06505145670573
epoch 16 total_loss_1: 159.6399653483073 reconstruction_loss:  120.32449708658854 KLD:  39.31546833496094
epoch 17 total_loss_1: 155.5541121582031 reconstruction_loss:  115.93834047851563 KLD:  39.615771911621096
epoch 18 total_loss_1: 152.08300314127604 reconstruction_loss:  112.34369606119792 KLD:  39.73930708007813
epoch 19 total_loss_1: 149.19920162760417 reconstruction_loss:  109.36712736002605 KLD:  39.83207437337239

[image: Final.png (generated samples)]
PyTorch feed-forward VAE model

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
#%matplotlib inline
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image
from IPython.display import Image

batch_size = 100 #- 500
epochs = 20
learning_rate = 1e-3

if torch.cuda.is_available():  
  dev = "cuda:0" 
else:  
  dev = "cpu" 

device = torch.device(dev)

Train = True

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(784,20**2),
            nn.ReLU(),
            nn.Linear(20**2,20*2)
        )

        self.fc_mu = nn.Linear(in_features=20*2, out_features=20)
        self.fc_logvar = nn.Linear(in_features=20*2, out_features=20)

        self.decoder = nn.Sequential(
            nn.Linear(20,20**2),
            nn.ReLU(),
            nn.Linear(20**2,784),
            nn.Sigmoid()
        )

    def passToEnoch(self, image):

        bottleNeck = self.encoder(image)

        return bottleNeck

    def passToDenoch(self, BN):

        GenImg = self.decoder(BN)

        return GenImg
    
    def Reparameterise(self, mean, logvar):
       
        std = logvar.mul(0.5).exp_()
        eps = std.data.new(std.size()).normal_()
        return eps.mul(std).add_(mean)
        

    def forward(self, image):

        output = self.encoder(image.reshape(-1,784))

        mu = self.fc_mu(output)
        logvar =  self.fc_logvar(output)

        z = self.Reparameterise(mu, logvar)
        
        construction = self.passToDenoch(z)
        
        return construction, mu, logvar

train_set = torchvision.datasets.MNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)


train_loader = torch.utils.data.DataLoader(train_set, batch_size=len(train_set), shuffle=True)
data, label = next(iter(train_loader))
mean = data.mean()
std = data.std()



train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)

network = Model()
#network = network.to(torch.device("cuda:0"))

optimizer = optim.Adam(network.parameters(), lr=0.001)


def loss_function(pred, images, mu, logvar):

    criterion = nn.BCELoss(reduction = 'sum')
    
    reconstructionLoss = criterion(pred,images.reshape(-1,784))

    KLD  = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return reconstructionLoss + KLD



for epoch in range(20):

    total_loss_1 = 0
    total_loss_2 = 0


    for batch in train_loader: # Get Batch
        
        images, _ = batch
        #images = (images - mean) / (std + 1e-15)
        #images = images.to(torch.device("cuda:0"))
        
        construction, mu, logvar = network(images) # Pass Batch
        loss = loss_function(construction, images, mu, logvar) # Calculate Loss

        optimizer.zero_grad()
        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights

        total_loss_1 += loss.item()

    print("epoch", epoch, "total_loss_1:", total_loss_1/60000)

torch.save(network.state_dict(), 'fc_VAE.pth')

#model = torch.load('VAE.pth')


with torch.no_grad():
    network.eval()

    o = np.random.normal(0,1, (100,20))
    o = torch.from_numpy(o).float()
    Train = False
    
    reconstruction = network.passToDenoch(o).reshape(100,1,28,28)

    save_image(reconstruction,'Final_fc.png')
    Image('Final_fc.png') 
epoch 0 total_loss_1: 164.42614547526043 reconstruction_loss:  149.9676513346354 KLD:  14.458494481976826
epoch 1 total_loss_1: 124.39563199869792 reconstruction_loss:  103.829515234375 KLD:  20.566116786702473
epoch 2 total_loss_1: 115.76229205729166 reconstruction_loss:  93.75830346679687 KLD:  22.00398865559896
epoch 3 total_loss_1: 112.11807252604167 reconstruction_loss:  89.46627601725261 KLD:  22.65179629313151
epoch 4 total_loss_1: 110.07662482096354 reconstruction_loss:  86.96918548177084 KLD:  23.107439229329426
epoch 5 total_loss_1: 108.80403064778646 reconstruction_loss:  85.46129479980469 KLD:  23.34273593343099
epoch 6 total_loss_1: 107.87016588541667 reconstruction_loss:  84.34400255533855 KLD:  23.526163228352864
epoch 7 total_loss_1: 107.15195069986979 reconstruction_loss:  83.49743802897136 KLD:  23.654512723795573
epoch 8 total_loss_1: 106.5942380045573 reconstruction_loss:  82.87207277832032 KLD:  23.722165393066405
epoch 9 total_loss_1: 106.19866038411459 reconstruction_loss:  82.36383403320312 KLD:  23.834826477050783
epoch 10 total_loss_1: 105.81030188802083 reconstruction_loss:  81.90263583170572 KLD:  23.907666267903647
epoch 11 total_loss_1: 105.47002320963541 reconstruction_loss:  81.53675849609375 KLD:  23.93326481933594
epoch 12 total_loss_1: 105.20676300455729 reconstruction_loss:  81.21078239746093 KLD:  23.995980533854166
epoch 13 total_loss_1: 104.91488715820313 reconstruction_loss:  80.9069107421875 KLD:  24.007976521809894
epoch 14 total_loss_1: 104.69429993489584 reconstruction_loss:  80.63368645019531 KLD:  24.06061327311198
epoch 15 total_loss_1: 104.45741033528645 reconstruction_loss:  80.39678291015625 KLD:  24.060627416992187
epoch 16 total_loss_1: 104.2428307454427 reconstruction_loss:  80.16855849609375 KLD:  24.07427227783203
epoch 17 total_loss_1: 104.09353636067708 reconstruction_loss:  79.99055166015626 KLD:  24.10298465576172
epoch 18 total_loss_1: 103.96265613606771 reconstruction_loss:  79.81093037109375 KLD:  24.151725834147136
epoch 19 total_loss_1: 103.76769731445313 reconstruction_loss:  79.62591543782553 KLD:  24.141781754557293

Your CNN is too small (layers with fewer than 1000 parameters each; e.g. the second layer trains only 12*6*3*3 + 12 = 660 scalars). Compare this to the 400x784 matrices used in the FC network. Additionally, mu and logvar should probably be generated using linear layers.
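For reference, a quick way to verify these parameter counts (a minimal sketch, assuming network is the Model instance defined above):

total = sum(p.numel() for p in network.parameters() if p.requires_grad)
print("total trainable parameters:", total)
for name, p in network.named_parameters():
    print(name, tuple(p.shape), p.numel())  # e.g. encoder.2.weight -> (12, 6, 3, 3), 648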

Thank you for your response.
Do you think I should increase the number of channels or add additional convolutions?

Also, why should mu and logvar be generated with linear layers rather than convolutions?

Try network configurations like the ones here, for example.

Convolution outputs are location-invariant; if all features are allowed to interact, more expressive models are possible.
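For illustration, one way to wire this up (a minimal sketch, not your full model; the shapes assume the 28x28 MNIST input, where your two strided convolutions produce a (N, 12, 7, 7) feature map):

# in __init__:
self.fc_mu = nn.Linear(12 * 7 * 7, 20)
self.fc_logvar = nn.Linear(12 * 7 * 7, 20)

# replacing passToEnoch:
def passToEnoch(self, image):
    t = self.encoder(image)       # (N, 12, 7, 7)
    t = t.flatten(start_dim=1)    # (N, 588): all features can now interact
    return self.fc_mu(t), self.fc_logvar(t)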

Even after increasing the number of parameters (and becoming egregiously slow), it is still unable to generate images. I'm assuming it's overfitting, but I can't say for certain. Even though it reaches loss levels similar to the fully connected model, it doesn't perform as well.

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=24, kernel_size=3, stride=2, padding=1),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=24, out_channels=48, kernel_size=3, stride=2, padding=1),
        nn.LeakyReLU(),
        nn.Conv2d(in_channels=48, out_channels=96, kernel_size=3)
        )

        self.fc1 = nn.Linear(in_features=96*5*5, out_features=1200)
        self.fc2 = nn.Linear(in_features=1200, out_features=600)
        self.fc_mu = nn.Linear(in_features=600, out_features=20)
        self.fc_logvar = nn.Linear(in_features=600, out_features=20)

        self.conv5 =  nn.Linear(in_features=20, out_features=96*5*5) 


        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=96, out_channels=48, kernel_size=3),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=48, out_channels=24, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=24, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )


    def passToEnoch(self, image):

        t = image

        t = self.encoder(t)
        t = self.fc1(t.reshape(100,-1))
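        # (note: reshape(100, -1) hard-codes the batch size; t.reshape(t.size(0), -1) is more general)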
        t = self.fc2(t)
        mu = self.fc_mu(t)
        logvar = self.fc_logvar(t)

        return mu, logvar

    def passToDenoch(self, BN):

        t = BN
       
        t = t.reshape(100,96,5,5)
        t = self.decoder(t)
        GenImg = t

        return GenImg
    
    def Reparameterise(self, mean, logvar):
       
        std = logvar.mul(0.5).exp_()
        eps = std.data.new(std.size()).normal_()
        return eps.mul(std).add_(mean)
        

    def forward(self, image):

        mu, logvar = self.passToEnoch(image)

        
        z = self.Reparameterise(mu, logvar)
        z = self.conv5(z)

        construction = self.passToDenoch(z)
        
        return construction, mu, logvar
epoch 0 total_loss_1: 254.0756600423177 reconstruction_loss:  231.1596802246094 KLD:  22.915979824928442
epoch 1 total_loss_1: 141.40640200195313 reconstruction_loss:  116.31750934244792 KLD:  25.08889241129557
epoch 2 total_loss_1: 126.85437965494792 reconstruction_loss:  100.08900904947917 KLD:  26.765370650227865
epoch 3 total_loss_1: 120.90743160807291 reconstruction_loss:  93.52300227864583 KLD:  27.38442930094401
epoch 4 total_loss_1: 117.74371357421875 reconstruction_loss:  90.27648059895833 KLD:  27.467232918294272
epoch 5 total_loss_1: 115.44130123697917 reconstruction_loss:  87.98953850911458 KLD:  27.451762455240885
epoch 6 total_loss_1: 113.64397905273438 reconstruction_loss:  86.25507771809896 KLD:  27.388901236979166
epoch 7 total_loss_1: 112.21055922851562 reconstruction_loss:  84.92787290852864 KLD:  27.282686189778644
epoch 8 total_loss_1: 111.06344233398437 reconstruction_loss:  83.91212922363282 KLD:  27.15131326904297
epoch 9 total_loss_1: 110.11857666015625 reconstruction_loss:  83.0612642171224 KLD:  27.05731242675781
epoch 10 total_loss_1: 109.25627236328125 reconstruction_loss:  82.33233040364583 KLD:  26.92394188639323
epoch 11 total_loss_1: 108.61756567382812 reconstruction_loss:  81.75407697753906 KLD:  26.863488631184897
epoch 12 total_loss_1: 108.02086090494792 reconstruction_loss:  81.28328525390624 KLD:  26.737575720214842
epoch 13 total_loss_1: 107.53491346028646 reconstruction_loss:  80.84818310546875 KLD:  26.686730619303386
epoch 14 total_loss_1: 107.04162428385416 reconstruction_loss:  80.43439178873697 KLD:  26.607232442220052
epoch 15 total_loss_1: 106.61778201497395 reconstruction_loss:  80.10782824707032 KLD:  26.509954032389324
epoch 16 total_loss_1: 106.26496030273438 reconstruction_loss:  79.8152283203125 KLD:  26.449732063802085
epoch 17 total_loss_1: 105.94519632161459 reconstruction_loss:  79.5302062906901 KLD:  26.414990108235678
epoch 18 total_loss_1: 105.67471401367187 reconstruction_loss:  79.33153162434895 KLD:  26.343182397460936
epoch 19 total_loss_1: 105.33589560546875 reconstruction_loss:  79.04250388183594 KLD:  26.293391638183593

[image: generated samples from the updated model]

Now your generative routine is incorrect, because it doesn't use the "conv5" linear layer that decodes the Gaussian variates.
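In other words, the sampling routine has to mirror what forward does after Reparameterise (a minimal sketch, assuming the updated Model above):

with torch.no_grad():
    network.eval()
    z = torch.randn(100, 20)                  # sample the 20-d prior N(0, I)
    z = network.conv5(z)                      # the linear layer that decodes the Gaussian variates
    reconstruction = network.passToDenoch(z)  # reshapes to (100, 96, 5, 5) and decodes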

Here is my output. I removed the fc1 and fc2 layers, which seem unneeded, especially without activations between them. A learning-rate increase or scheduling would probably improve the results. As for speed, yes, it is pretty bad without a GPU.

[image: generated samples]

Thank you, it worked. It had to do with the linear layer that decodes the Gaussian variates.