Variational autoencoder: the same reconstructed images

Hi, I want to check how the VAE reconstructs the image, but for some reason, I got the same images for different inputs. I would be very happy if you will tell me where I made the mistake.

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Flatten()
        )
        self.z_mean = torch.nn.Linear(3136, 2)
        self.z_log_var = torch.nn.Linear(3136, 2)
 

        self.decoder = nn.Sequential(
                torch.nn.Linear(2, 3136),
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 32, stride=(2, 2), kernel_size=(3, 3), padding=0), 
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, stride=(1, 1), kernel_size=(3, 3), padding=0), 
                Trim(),  # 1x29x29 -> 1x28x28
                nn.Sigmoid()
                )
 
    def gaussian_sampler(self, mu, logsigma):
        if self.training:
            eps = torch.randn(mu.size(0), mu.size(1)).to(device)
            z = mu + eps * torch.exp(logsigma/2.) 
            return z

        else:
            return mu
    
    def forward(self, x):
        mu, logsigma = self.z_mean(self.encoder(x)), self.z_log_var(self.encoder(x))
        reconstruction = self.decoder(self.gaussian_sampler(mu, logsigma))

        return mu, logsigma, reconstruction
def KL_divergence(mu, logsigma):

    loss = - 0.5 * torch.mean(1 + logsigma - mu ** 2 - logsigma.exp()) 
    return loss

def log_likelihood(x, reconstruction):

    loss = nn.BCELoss(reduction = 'mean')#<binary cross-entropy>
    return loss(reconstruction, x)

def loss_vae(x, mu, logsigma, reconstruction):
    return KL_divergence(mu, logsigma) + log_likelihood(x, reconstruction)

n_epochs = 25
train_losses = []
val_losses = []

for epoch in tqdm(range(n_epochs)):
    autoencoder.train()
    train_losses_per_epoch = []
    for batch in train_loader:
        optimizer.zero_grad()
        mu, logsigma, reconstruction = autoencoder(torch.tensor(batch[0]).to(device))
        #print(reconstruction.shape, mu.shape, logsigma.shape, end = '\n' )
        #reconstruction = reconstruction.view(-1, 28, 28, 3)
        loss = criterion(batch[0].to(device).float(), mu, logsigma, reconstruction)
        loss.backward()
        optimizer.step()
        train_losses_per_epoch.append(loss.item())

    train_losses.append(np.mean(train_losses_per_epoch))

    autoencoder.eval()
    val_losses_per_epoch = []
    with torch.no_grad():
        for batch in val_loader:
            mu, logsigma, reconstruction = autoencoder(torch.tensor(batch[0]).to(device))
            #reconstruction = reconstruction.view(-1, 64, 64, 3)
            loss = criterion(batch[0].to(device).float(), mu, logsigma, reconstruction)
            val_losses_per_epoch.append(loss.item())

    val_losses.append(np.mean(val_losses_per_epoch))



result_val = []
ground_truth_val = []
autoencoder.eval()
with torch.no_grad():
    for batch in val_loader:
        print(batch[0].shape)
        mu, logsigma, reconstruction = autoencoder(batch[0].to(device))
        print(reconstruction.shape)
        #print( autoencoder.gaussian_sampler(mu, logsigma))
        #reconstruction = reconstruction.view(-1, 64, 64, 3)
        result = reconstruction.cpu().detach().numpy()
        ground_truth = batch[0].numpy()
        result_val.extend(result)
        ground_truth_val.extend(ground_truth)
        break


import matplotlib.pyplot as plt


plt.figure(figsize=(8, 20))
for i, (gt, res) in enumerate(zip(ground_truth_val[:10], result_val[:10])):
    plt.subplot(10, 2, 2*i+1)
    plt.imshow(gt[0])
    plt.subplot(10, 2, 2*i+2)
    plt.imshow(res[0])