NOOB needs help with VAE encoder [training loss not declining] (help)

import torch
from torch import nn 
import torch.distributions as dist
import numpy as np
from PIL import Image
import random
from keras.datasets import mnist
import pandas as pd
import torch.nn.functional as F
import tensorflow as tf
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, TensorDataset, DataLoader


"""
VAE (Auto Encoder)
"""

#Does not have a activation function on output because we want to perserve the data relation and not limit it to sigmoid
class VAE(nn.Module):

    def __init__(self, state_dim=784, action_dim=32, hidden_dim=400):
        super(VAE, self).__init__()
        
        
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        self.log_var = nn.Linear(hidden_dim, action_dim)

        self.fc2 = nn.Linear(action_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, state_dim)
        
        
    def encoder(self, x):
        x = F.relu(self.fc1(x))
        mu = self.mu(x)
        log_var = self.log_var(x)
        return mu, log_var  


      
    def reparameterization(self,mu,log_var):
      std = torch.exp(0.5 * log_var)
      eps = torch.randn_like(std)
      z = mu + eps*std
      return z


    def decoder(self, z):
      x = F.relu(self.fc2(z))
      pred = torch.sigmoid(self.fc3(x))
      
      return pred



    def forward(self,x):
      mu, log_var = self.encoder(x)
      z = self.reparameterization(mu,log_var)
      recon_x = self.decoder(z)
      return recon_x, mu, log_var



    def loss_calculation(self, log_var, mu, recon_x, original_x, beta=.7, batch_size=128):
      recon_loss = F.mse_loss(recon_x, original_x, reduction='sum') / batch_size
      kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
      total_loss = recon_loss + beta * kl
      return recon_loss, kl, total_loss








#Data preperation

(train_X, train_y), (test_X, test_y) = mnist.load_data()

train_X = train_X.reshape(60000, 784).astype('float32') / 255.0



train_tensor = torch.tensor(train_X, dtype=torch.float32) # Changed this line

dataset = TensorDataset(train_tensor)

batch_size = 128

train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)






vae = VAE()
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)



for epoch in range(1000):  # num_epochs
    total_loss = 0
    num_batches = 0

    for batch in train_loader:
        x = batch[0]  

        recon_x, mu, log_var = vae.forward(x)
        # plt.imshow(recon_x[0].detach().numpy().reshape(28,28))
        # plt.show()
        recon_loss, kl, loss = vae.loss_calculation(log_var, mu, recon_x, x, batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.2f}")

Epoch 1: Loss = 55.86
Epoch 2: Loss = 53.12
Epoch 3: Loss = 53.02
Epoch 4: Loss = 52.97
Epoch 5: Loss = 52.92
Epoch 6: Loss = 52.91
Epoch 7: Loss = 52.87
Epoch 8: Loss = 52.85
Epoch 9: Loss = 52.84
Epoch 10: Loss = 52.81
Epoch 11: Loss = 52.80
Epoch 12: Loss = 52.79


I tried to change the action layer from 2 → 32 → 64 but that did not work.

I do not know why my loss is barely decreasing.

I suspect that this is an architectural issue, but I am not sure as I have tried different types