Hi, I want to check how the VAE reconstructs the image, but for some reason, I got the same images for different inputs. I would be very happy if you will tell me where I made the mistake.
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.Flatten()
)
self.z_mean = torch.nn.Linear(3136, 2)
self.z_log_var = torch.nn.Linear(3136, 2)
self.decoder = nn.Sequential(
torch.nn.Linear(2, 3136),
Reshape(-1, 64, 7, 7),
nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.01),
nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0),
nn.LeakyReLU(0.01),
nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0),
Trim(), # 1x29x29 -> 1x28x28
nn.Sigmoid()
)
def gaussian_sampler(self, mu, logsigma):
if self.training:
eps = torch.randn(mu.size(0), mu.size(1)).to(device)
z = mu + eps * torch.exp(logsigma/2.)
return z
else:
return mu
def forward(self, x):
mu, logsigma = self.z_mean(self.encoder(x)), self.z_log_var(self.encoder(x))
reconstruction = self.decoder(self.gaussian_sampler(mu, logsigma))
return mu, logsigma, reconstruction
def KL_divergence(mu, logsigma):
loss = - 0.5 * torch.mean(1 + logsigma - mu ** 2 - logsigma.exp())
return loss
def log_likelihood(x, reconstruction):
loss = nn.BCELoss(reduction = 'mean')#<binary cross-entropy>
return loss(reconstruction, x)
def loss_vae(x, mu, logsigma, reconstruction):
return KL_divergence(mu, logsigma) + log_likelihood(x, reconstruction)
n_epochs = 25
train_losses = []
val_losses = []
for epoch in tqdm(range(n_epochs)):
autoencoder.train()
train_losses_per_epoch = []
for batch in train_loader:
optimizer.zero_grad()
mu, logsigma, reconstruction = autoencoder(torch.tensor(batch[0]).to(device))
#print(reconstruction.shape, mu.shape, logsigma.shape, end = '\n' )
#reconstruction = reconstruction.view(-1, 28, 28, 3)
loss = criterion(batch[0].to(device).float(), mu, logsigma, reconstruction)
loss.backward()
optimizer.step()
train_losses_per_epoch.append(loss.item())
train_losses.append(np.mean(train_losses_per_epoch))
autoencoder.eval()
val_losses_per_epoch = []
with torch.no_grad():
for batch in val_loader:
mu, logsigma, reconstruction = autoencoder(torch.tensor(batch[0]).to(device))
#reconstruction = reconstruction.view(-1, 64, 64, 3)
loss = criterion(batch[0].to(device).float(), mu, logsigma, reconstruction)
val_losses_per_epoch.append(loss.item())
val_losses.append(np.mean(val_losses_per_epoch))
result_val = []
ground_truth_val = []
autoencoder.eval()
with torch.no_grad():
for batch in val_loader:
print(batch[0].shape)
mu, logsigma, reconstruction = autoencoder(batch[0].to(device))
print(reconstruction.shape)
#print( autoencoder.gaussian_sampler(mu, logsigma))
#reconstruction = reconstruction.view(-1, 64, 64, 3)
result = reconstruction.cpu().detach().numpy()
ground_truth = batch[0].numpy()
result_val.extend(result)
ground_truth_val.extend(ground_truth)
break
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 20))
for i, (gt, res) in enumerate(zip(ground_truth_val[:10], result_val[:10])):
plt.subplot(10, 2, 2*i+1)
plt.imshow(gt[0])
plt.subplot(10, 2, 2*i+2)
plt.imshow(res[0])