I am trying to implement a VAE with convolutions, but for some reason it is not working, whereas the same VAE built from a feed-forward network works. I do not know which difference is causing the problem; please help. Thank you.
One thing to note is that the CNN-VAE loss never drops below 140 and seems to converge too early, or at least that is what I’m seeing.
The FC VAE loss reaches around 100 and performs well.
For some reason the KLD of the CNN-VAE stays at around 40, while for the FC VAE it is around 24.
PyTorch CNN VAE model:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
#%matplotlib inline
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from IPython.display import Image
# Hyper-parameters read by the data loader, optimizer and training loop.
batch_size = 100  # originally annotated "#- 500" (presumably an alternative value tried)
epochs = 20
learning_rate = 1e-3

# Prefer the first CUDA device when one is available, else the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Set to False by the sampling section; nothing visible here reads it.
Train = True
class Model(nn.Module):
    """Convolutional VAE for single-channel 28x28 images.

    Encoder: two stride-2, padding-1, 3x3 convolutions with ReLU
    (1 -> 6 -> 12 channels, 28 -> 14 -> 7 pixels), then two parallel
    unpadded 3x3 convolutions producing the posterior mean and log-variance,
    each of shape (16, 5, 5).  The latent code is therefore a spatial tensor
    of 16 * 5 * 5 = 400 values per image, and a KL term summed over all
    latent elements grows with that count.

    Decoder: transposed convolutions 16x5x5 -> 12x7x7 -> 6x14x14 -> 1x28x28,
    with a final sigmoid so every output pixel lies in [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(6, 12, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Two heads over the same 12-channel features: mean and log-variance.
        self.conv3 = nn.Conv2d(12, 16, kernel_size=3)
        self.conv4 = nn.Conv2d(12, 16, kernel_size=3)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 12, kernel_size=3),
            nn.ReLU(),
            nn.ConvTranspose2d(12, 6, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(6, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def passToEnoch(self, image):
        """Encode a batch of images into the posterior parameters ``(mu, logvar)``."""
        features = self.encoder(image)
        return self.conv3(features), self.conv4(features)

    def passToDenoch(self, BN):
        """Decode a latent batch (16x5x5 per sample) into images."""
        return self.decoder(BN)

    def Reparameterise(self, mean, logvar):
        """Sample ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        noise = torch.randn_like(logvar)
        return mean + noise * (0.5 * logvar).exp()

    def forward(self, image):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.passToEnoch(image)
        sample = self.Reparameterise(mu, logvar)
        return self.passToDenoch(sample), mu, logvar
# MNIST training split, stored under ./data (downloaded if missing);
# every sample is converted to a tensor by ToTensor.
_to_tensor = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=_to_tensor,
)
# Shuffled mini-batches; the size comes from the `batch_size` setting
# instead of being repeated as a literal.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

network = Model()

# Use the configured `learning_rate` instead of a hard-coded 1e-4.  The
# feed-forward VAE trains with 1e-3; at 1e-4 Adam takes 10x smaller steps,
# so 20 epochs leave the convolutional model far from converged.
optimizer = optim.Adam(network.parameters(), lr=learning_rate)
def loss_function(pred, images, mu, logvar):
    """Compute the two VAE loss terms, each summed over the whole batch.

    Args:
        pred: decoder output with values in [0, 1], same shape as `images`.
        images: target images with values in [0, 1].
        mu: posterior means.
        logvar: posterior log-variances, same shape as `mu`.

    Returns:
        ``(reconstruction, kld)``: the summed binary cross-entropy and the
        summed closed-form KL divergence of N(mu, exp(logvar)) from N(0, I).
    """
    bce = F.binary_cross_entropy(pred, images, reduction='sum')
    kld = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar)).sum()
    return bce, kld
for epoch in range(20):
total_loss_1 = 0
reconstruction_Loss = 0
KLD_loss_2 = 0
for batch in train_loader: # Get Batch
images, _ = batch
optimizer.zero_grad()
construction, mu, logvar = network(images) # Pass Batch
reconstructionLoss, KLD = loss_function(construction, images, mu, logvar) # Calculate Loss
loss = reconstructionLoss + KLD
loss.backward() # Calculate Gradients
optimizer.step() # Update Weights
total_loss_1 += loss.item()
reconstruction_Loss += reconstructionLoss.item()
KLD_loss_2 += KLD.item()
print("epoch", epoch, "total_loss_1:", total_loss_1/60000, "reconstruction_loss: ", reconstruction_Loss/60000,"KLD: ", KLD_loss_2/60000)
torch.save(network.state_dict(), 'VAE.pth')
# Decode 100 latent tensors drawn from the N(0, I) prior and save them as an
# image grid.  The 16x5x5 latent shape is what the encoder heads produce.
network.eval()
with torch.no_grad():
    o = torch.from_numpy(np.random.normal(0, 1, (100, 16, 5, 5))).float()
    Train = False
    reconstruction = network.passToDenoch(o)
    save_image(reconstruction, 'Final.png')
Image('Final.png')
epoch 0 total_loss_1: 437.0815269856771 reconstruction_loss: 424.89729869791665 KLD: 12.184227733802796
epoch 1 total_loss_1: 271.06046875 reconstruction_loss: 235.04757037760416 KLD: 36.012898413085935
epoch 2 total_loss_1: 237.2292015625 reconstruction_loss: 201.87340013020832 KLD: 35.355801139322914
epoch 3 total_loss_1: 225.27503678385418 reconstruction_loss: 189.86380341796874 KLD: 35.411233211263024
epoch 4 total_loss_1: 218.279416015625 reconstruction_loss: 182.6045441731771 KLD: 35.674871655273435
epoch 5 total_loss_1: 212.743983203125 reconstruction_loss: 176.58763538411458 KLD: 36.156347798665365
epoch 6 total_loss_1: 208.0514568033854 reconstruction_loss: 171.18413247070313 KLD: 36.867324344889326
epoch 7 total_loss_1: 203.88413160807292 reconstruction_loss: 166.22725283203124 KLD: 37.65687854817708
epoch 8 total_loss_1: 200.00452884114583 reconstruction_loss: 161.81875375976563 KLD: 38.18577508951823
epoch 9 total_loss_1: 196.12084052734374 reconstruction_loss: 157.47882067057293 KLD: 38.642020076497396
epoch 10 total_loss_1: 191.95198984375 reconstruction_loss: 153.00707041015625 KLD: 38.94491872151693
epoch 11 total_loss_1: 187.28410458984376 reconstruction_loss: 148.25767250976563 KLD: 39.02643206787109
epoch 12 total_loss_1: 181.97289752604166 reconstruction_loss: 143.05000634765625 KLD: 38.92289149169922
epoch 13 total_loss_1: 176.001284375 reconstruction_loss: 137.28636323242188 KLD: 38.714921118164064
epoch 14 total_loss_1: 170.22259750976562 reconstruction_loss: 131.3868727701823 KLD: 38.83572508951823
epoch 15 total_loss_1: 164.65682986653647 reconstruction_loss: 125.59177838541666 KLD: 39.06505145670573
epoch 16 total_loss_1: 159.6399653483073 reconstruction_loss: 120.32449708658854 KLD: 39.31546833496094
epoch 17 total_loss_1: 155.5541121582031 reconstruction_loss: 115.93834047851563 KLD: 39.615771911621096
epoch 18 total_loss_1: 152.08300314127604 reconstruction_loss: 112.34369606119792 KLD: 39.73930708007813
epoch 19 total_loss_1: 149.19920162760417 reconstruction_loss: 109.36712736002605 KLD: 39.83207437337239
PyTorch feed-forward VAE model:
import numpy as np
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
#%matplotlib inline
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image
from IPython.display import Image
# Hyper-parameters read by the data loader, optimizer and training loop.
batch_size = 100  # originally annotated "#- 500" (presumably an alternative value tried)
epochs = 20
learning_rate = 1e-3

# Default to the CPU; upgrade to the first CUDA device when one exists.
dev = "cpu"
if torch.cuda.is_available():
    dev = "cuda:0"
device = torch.device(dev)

# Set to False by the sampling section; nothing visible here reads it.
Train = True
class Model(nn.Module):
    """Fully-connected VAE over flattened 28x28 images with a 20-d latent code.

    Encoder: 784 -> 400 (ReLU) -> 40, then two linear heads mapping the 40
    features to the posterior mean and log-variance (20 values each).
    Decoder: 20 -> 400 (ReLU) -> 784 with a final sigmoid, so outputs lie in
    [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Layer creation order is unchanged so parameter initialisation and
        # state_dict keys match previously saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Linear(784, 20**2),
            nn.ReLU(),
            nn.Linear(20**2, 20*2),
        )
        self.fc_mu = nn.Linear(in_features=20*2, out_features=20)
        self.fc_logvar = nn.Linear(in_features=20*2, out_features=20)
        self.decoder = nn.Sequential(
            nn.Linear(20, 20**2),
            nn.ReLU(),
            nn.Linear(20**2, 784),
            nn.Sigmoid(),
        )

    def passToEnoch(self, image):
        """Encode images into the posterior parameters ``(mu, logvar)``.

        Args:
            image: tensor holding 784 values per sample, e.g. ``(B, 1, 28, 28)``
                or ``(B, 784)``; it is flattened to ``(B, 784)`` first.

        Returns:
            ``(mu, logvar)``, each of shape ``(B, 20)``.
        """
        features = self.encoder(image.reshape(-1, 784))
        return self.fc_mu(features), self.fc_logvar(features)

    def passToDenoch(self, BN):
        """Decode latent vectors of shape ``(B, 20)`` into ``(B, 784)`` images."""
        return self.decoder(BN)

    def Reparameterise(self, mean, logvar):
        """Sample ``mean + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        Uses ``torch.randn_like`` and out-of-place arithmetic.  The former
        ``std.data.new(...).normal_()`` relied on the legacy ``.data``
        attribute and in-place ops.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, image):
        """Return ``(reconstruction, mu, logvar)``; reconstruction is ``(B, 784)``."""
        mu, logvar = self.passToEnoch(image)
        z = self.Reparameterise(mu, logvar)
        return self.passToDenoch(z), mu, logvar
# MNIST training split, stored under ./data (downloaded if missing);
# every sample is converted to a tensor by ToTensor.
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# Dataset-wide pixel statistics, kept for the (currently commented-out)
# input normalisation in the training loop.  They are computed directly from
# the raw uint8 images scaled to [0, 1] (the same scaling ToTensor applies),
# instead of pushing all samples through a DataLoader as one shuffled batch.
# NOTE(review): BCE needs targets in [0, 1], so enabling that normalisation
# would also require keeping an un-normalised copy as the target.
_pixels = train_set.data.float().div(255)
mean = _pixels.mean()
std = _pixels.std()
del _pixels

# Shuffled mini-batches sized by the `batch_size` setting.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
network = Model()
optimizer = optim.Adam(network.parameters(), lr=learning_rate)
def loss_function(pred, images, mu, logvar):
    """Return the batch loss: summed BCE reconstruction term plus summed KL term.

    Args:
        pred: decoder output of shape ``(B, 784)`` with values in [0, 1].
        images: target images holding 784 values per sample; flattened here.
        mu: posterior means.
        logvar: posterior log-variances, same shape as `mu`.

    Returns:
        A scalar tensor: BCE(pred, images) + KL(N(mu, exp(logvar)) || N(0, I)).
    """
    target = images.reshape(-1, 784)
    bce = F.binary_cross_entropy(pred, target, reduction='sum')
    kld = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum()
    return bce + kld
for epoch in range(20):
total_loss_1 = 0
total_loss_2 = 0
for batch in train_loader: # Get Batch
images, _ = batch
#images = (images - mean) / (std + 1e-15)
#images = images.to(torch.device("cuda:0"))
construction, mu, logvar = network(images) # Pass Batch
loss = loss_function(construction, images, mu, logvar) # Calculate Loss
optimizer.zero_grad()
loss.backward() # Calculate Gradients
optimizer.step() # Update Weights
total_loss_1 += loss.item()
print("epoch", epoch, "total_loss_1:", total_loss_1/60000)
torch.save(network.state_dict(), 'fc_VAE.pth')
# Decode 100 latent vectors drawn from the N(0, I) prior, reshape the flat
# 784-value outputs to 1x28x28 and save them as an image grid.
network.eval()
with torch.no_grad():
    o = torch.from_numpy(np.random.normal(0, 1, (100, 20))).float()
    Train = False
    flat_images = network.passToDenoch(o)
    reconstruction = flat_images.reshape(100, 1, 28, 28)
    save_image(reconstruction, 'Final_fc.png')
Image('Final_fc.png')
epoch 0 total_loss_1: 164.42614547526043 reconstruction_loss: 149.9676513346354 KLD: 14.458494481976826
epoch 1 total_loss_1: 124.39563199869792 reconstruction_loss: 103.829515234375 KLD: 20.566116786702473
epoch 2 total_loss_1: 115.76229205729166 reconstruction_loss: 93.75830346679687 KLD: 22.00398865559896
epoch 3 total_loss_1: 112.11807252604167 reconstruction_loss: 89.46627601725261 KLD: 22.65179629313151
epoch 4 total_loss_1: 110.07662482096354 reconstruction_loss: 86.96918548177084 KLD: 23.107439229329426
epoch 5 total_loss_1: 108.80403064778646 reconstruction_loss: 85.46129479980469 KLD: 23.34273593343099
epoch 6 total_loss_1: 107.87016588541667 reconstruction_loss: 84.34400255533855 KLD: 23.526163228352864
epoch 7 total_loss_1: 107.15195069986979 reconstruction_loss: 83.49743802897136 KLD: 23.654512723795573
epoch 8 total_loss_1: 106.5942380045573 reconstruction_loss: 82.87207277832032 KLD: 23.722165393066405
epoch 9 total_loss_1: 106.19866038411459 reconstruction_loss: 82.36383403320312 KLD: 23.834826477050783
epoch 10 total_loss_1: 105.81030188802083 reconstruction_loss: 81.90263583170572 KLD: 23.907666267903647
epoch 11 total_loss_1: 105.47002320963541 reconstruction_loss: 81.53675849609375 KLD: 23.93326481933594
epoch 12 total_loss_1: 105.20676300455729 reconstruction_loss: 81.21078239746093 KLD: 23.995980533854166
epoch 13 total_loss_1: 104.91488715820313 reconstruction_loss: 80.9069107421875 KLD: 24.007976521809894
epoch 14 total_loss_1: 104.69429993489584 reconstruction_loss: 80.63368645019531 KLD: 24.06061327311198
epoch 15 total_loss_1: 104.45741033528645 reconstruction_loss: 80.39678291015625 KLD: 24.060627416992187
epoch 16 total_loss_1: 104.2428307454427 reconstruction_loss: 80.16855849609375 KLD: 24.07427227783203
epoch 17 total_loss_1: 104.09353636067708 reconstruction_loss: 79.99055166015626 KLD: 24.10298465576172
epoch 18 total_loss_1: 103.96265613606771 reconstruction_loss: 79.81093037109375 KLD: 24.151725834147136
epoch 19 total_loss_1: 103.76769731445313 reconstruction_loss: 79.62591543782553 KLD: 24.141781754557293