Diffusion model loss not working properly

Hey everybody!
I am currently trying to implement a diffusion model using the UNET architecture in order to generate images based on labels. Since I am very new to the subject I started training it with the MNIST data set, since the images are of simpler kind.
The issue I am facing is the fact that the average loss of my model is always around one, it starts with numbers like 1.002 and continues with numbers around 1.0004. I tried changing the implementation of my UNET, thinking that maybe that was the problem, but I encountered the same problem.

This is the Main file

!pip install import_ipynb
import import_ipynb       #This is a package that allows me to get functions directly from colab notebooks

# Commented out IPython magic to ensure Python compatibility.
import os
import torch
import torch.nn as nn
import argparse
import numpy as np

#%cd "/content/drive/MyDrive/Fifth year/ClearBox/Diffusion_model_training"
# %cd "/content/drive/MyDrive/Diffusion_model_training"
from model import UNET
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms import ToPILImage
import torchvision
from torchsummary import summary
from tqdm import tqdm
import sys
from PIL import Image

def sin_time_embeding(time_vector, number_channels = 256):
  batch_size = np.shape(time_vector)[0]
  emb = torch.zeros([batch_size, number_channels])
  seq_channels = torch.arange(number_channels/2).long()
  for position, t in enumerate(time_vector):
    for i in seq_channels:
      value = t/(torch.pow(10000, 2*i/(number_channels)))
      sin = (torch.sin(value)).long()
      cos = torch.cos(value).long()

      emb[position, 2*i] = sin
      emb[position, 2*i+1] = cos
  return emb

class basics:
  def __init__(self, args, number_noise_steps = 1000, beta_start = 1e-4, beta_end = 0.02, image_size = 64, device = "cuda"):
    self.number_noise_steps = number_noise_steps
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.image_size = image_size
    self.device = device
    self.args = args
    self.beta = torch.linspace(self.beta_start, self.beta_end, self.number_noise_steps).to(self.device)
    
    self.alpha = 1 - self.beta
    self.big_alpha = torch.cumprod(self.alpha, dim = 0)
  
  
  def produce_noise(self, x0, time_position):
    time_position = time_position[:, None, None, None].to(self.args.device)

    part1 = (torch.sqrt(self.big_alpha[time_position]) * x0)
    part2 = (torch.sqrt(1 - self.big_alpha[time_position]) * torch.randn_like(x0))
    part3 = part1 + part2
    return part3, torch.randn_like(x0)
  

  def sampling_image(self, model, batch_size, label, classifier_scale = 3): #Labels has to have batch size
    print("Start Sampling")
    model.eval()
    x_noise = torch.randn(batch_size, 1, self.image_size, self.image_size).to(self.args.device)

    with torch.no_grad(): 
      for i in reversed(range(1, self.number_noise_steps)):
        t = (torch.ones(batch_size) * i).long().to(self.args.device)
        if i == 0:
          z = torch.zeros(x_noise.size())
        else:
          z = torch.rand_like(x_noise).to(self.args.device)

        alpha_buffer = self.alpha[t][:, None, None, None]
        big_alpha_buffer = self.big_alpha[t][:, None, None, None]
        beta_buffer = self.beta[t][:, None, None, None]
        
        
        sinusoidal_time_embeding = sin_time_embeding(t).to(self.args.device)

        pred_classified_noise = model(x_noise, sinusoidal_time_embeding, label)
        pred_noise = pred_classified_noise

        if classifier_scale > 0:  #The classifier scale is what defines the intensity of the interpolation towards the classified predicited noise
          pred_unclassified_noise = model(x_noise, sinusoidal_time_embeding, None)
          pred_interpolated_noise = torch.lerp(pred_unclassified_noise, pred_classified_noise,classifier_scale)
          pred_noise = pred_interpolated_noise

        part2 = ((1 - alpha_buffer)/(torch.sqrt(1 - big_alpha_buffer))) * pred_noise
        xtm = ((1/alpha_buffer) * (x_noise - part2)) + torch.sqrt(beta_buffer) * z
        x_noise = xtm
    model.train()
    return xtm

def train(args, model,dataloader, optmizer, loss, model_checkpoint = None):#Need to take it out of the basics object and create args 
  basic_obj = basics(args, args.number_noise_steps, args.beta_start, args.beta_end, args.image_size, args.device)
  
  if args.use_checkpoints == "True" and model_checkpoint != None:  #Load the checkpoints of the model
    print("Using checkpoint")
    model.load_state_dict(model_checkpoint['model_state_dict'])
    optmizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
    epoch = model_checkpoint["epoch"]
  else:
    epoch = 0

  while epoch < args.number_epochs:
    print("Epoch: ", epoch)
    list_losses = []
    for i, data in tqdm(enumerate(dataloader)):   #Iterating over the images from the dataloader
      optmizer.zero_grad()             #Setting gradient to zero after each iteration
      
      label = data[1].to(args.device)
      x0 = data[0].to(args.device)

      t = torch.randint(1, args.number_noise_steps, (args.batch_size, )).to(args.device)  #Getting a vector of time values the size of the bactch
      
      xt_rand, normal_distribution = basic_obj.produce_noise(x0, t)   #Generaring the noisy image at the specified time stamps from vector "t"
      xt_rand = xt_rand.to(args.device)

      normal_distribution = normal_distribution.to(args.device)

      sinusoidal_time_embeding = sin_time_embeding(t).to(args.device) #This needs to be done because the UNET only accepts the time tensor when it is trannformed

      if torch.rand(1) < 0.1:
        label = None

      x_pred = model(xt_rand, sinusoidal_time_embeding, label).to(args.device)    #Predicted images from the UNET by inputing the image and the time without the sinusoidal embeding
      
      Lsimple = loss(normal_distribution, x_pred).to(args.device)
      list_losses.append(Lsimple.item())
      Lsimple.backward()
      optmizer.step()

    epoch += 1
    
    #Saving Checkpoint
    try:    
      EPOCH = epoch
      PATH = args.checkpoint_directory + "/DiffusionModel.pt"      
      torch.save({
          'epoch': EPOCH,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optmizer.state_dict(),
          'LRSeg': args.learning_rate,
          }, PATH)
      print("checkpoint saved")
    except:
      pass

    labels_to_predict = torch.tensor(5).to(args.device)
    image_sample = basic_obj.sampling_image(model, 2, labels_to_predict)
    image_sample1 = torch.squeeze(image_sample[0])

    trsmr = ToPILImage()
    img_pil1 = trsmr(image_sample1)

    display(img_pil1)
    print("The average loss was: ", np.mean(list_losses))

def main(params):
  parser = argparse.ArgumentParser(description='Diffusion model')

  parser.add_argument('--device', type=str, default="cuda", help='Device to run the code on')
  parser.add_argument('--use_checkpoints', type=str, default="True", help='Use checkpoints')
  parser.add_argument('--emb_dimension', type=int, default=256, help='Number of embeded time dimension')
  parser.add_argument('--number_noise_steps', type=int, default=1000, help='Numbe of steps required to noise the image')
  parser.add_argument('--beta_start', type=float, default=1e-4, help='First value of beta')
  parser.add_argument('--beta_end', type=float, default=0.02, help='Last value of beta')
  parser.add_argument('--beta_curve', type=str, default="linear", help='How the value of beta will change over time')
  parser.add_argument('--image_size', type=int, default=32, help='Size of the squared input image')
  parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
  parser.add_argument('--number_workers', type=int, default=2, help='Number of workers for the dataloader')
  parser.add_argument('--number_steps', type=int, default=200, help='How many iterations steps the model will learn from')
  parser.add_argument('--number_epochs', type=int, default=10, help='Number of epochs the model will learn from')
  parser.add_argument('--learning_rate', type=float, default=1e-4, help='Initial learning rate of the optmizer')
  parser.add_argument('--number_classes', type=int, default=10, help='Number of classes for the classifier')
  parser.add_argument('--checkpoint_directory', \
  type=str, default="/content/drive/MyDrive/Fifth year/ClearBox/Diffusion_model_training/checkpoints", help='')

  args = parser.parse_args(params)

  #Import the Mninst dataset for training and validation
  dataset_train = MNIST("/content/MNIST_train", download=True, train=True,transform=torchvision.transforms.Compose([torchvision.transforms.Resize(32), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(0.5, 0.5)]),)

  dataloader_train = DataLoader(dataset_train,args.batch_size, num_workers=args.number_workers, shuffle=True, drop_last=True)
  args = parser.parse_args(params)

  diffusion_model = UNET(args, 1,1,number_classes_input=args.number_classes).to(args.device)

  optmizer = torch.optim.Adam(diffusion_model.parameters(), lr=args.learning_rate, betas=(0.9, 0.99))

  loss_mse = nn.MSELoss()
  
  #Doing the calculation for the number of iterations
  size_iterations = len(dataloader_train.dataset)/args.batch_size
  params_update = ['--number_steps', str(int(size_iterations))]
  params = params + params_update
  args = parser.parse_args(params)
  print("Number iterations: ", size_iterations)

  train(args, diffusion_model, dataloader_train, optmizer, loss_mse)

  #Needs to define a lot of things still, dataset class, dataloader, model

if __name__ == "__main__":  
  main(["--device", "cuda",
        "--batch_size", "32"])

This is the model

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import sys
from torch.nn import ZeroPad2d



#Predefined blocks

#Double convolution with Gelu activation
#The size of the input is going to be size of the output, except for the number of channels
class d_conv(nn.Module):
  def __init__(self, in_channels, out_channels, mid_channels = False, residual = False):    #Residuals define if there is going to be a residual connection from the input to the output
    super().__init__()

    self.residual = residual

    if not mid_channels:
      mid_channels = out_channels

    self.double_conv = nn.Sequential(
      nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
      nn.GroupNorm(1, mid_channels),
      nn.GELU(),
      nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
      nn.GroupNorm(1, out_channels),)
    
  def forward(self, x):
    if self.residual:
      output = F.gelu(x + self.double_conv(x))
    else:
      output = self.double_conv(x)
    return output

#Define the self attention block

class SA(nn.Module):    #Standard attention block
  def __init__(self, in_channels, dimension_input_image):
    super().__init__()
    self.in_channels = in_channels
    self.dimension_input_image = int(dimension_input_image)
    self.mha = nn.MultiheadAttention(in_channels, 4, batch_first=True)
    self.layer_norm = nn.LayerNorm([in_channels])

    self.feed_forward = nn.Sequential(
      nn.LayerNorm([in_channels]),
      nn.Linear(in_channels, in_channels),
      nn.GELU(),
      nn.Linear(in_channels, in_channels),
    )
  def forward(self, x):
    
    x = x.view(-1, self.in_channels, int(self.dimension_input_image**2)).swapaxes(1,2)
    
    norm_x = self.layer_norm(x)
    attention_value, _ = self.mha(norm_x, norm_x, norm_x)
    attention_value = attention_value + x
    attention_value = self.feed_forward(attention_value) + attention_value
    output = attention_value.swapaxes(2, 1).view(-1, self.in_channels, self.dimension_input_image, self.dimension_input_image)
    return output

#Defining Downsaple, midsample, and upsample blocks

class DS(nn.Module):
  def __init__(self, in_channels, out_channels, size_time_dimension = 256):
    super().__init__()
    self.pool_layer = nn.Sequential(   #Pooling layer with some convolutional blocks to change the number of channels and add the resitual
        nn.MaxPool2d(2),
        d_conv(in_channels, in_channels, residual=True),
        d_conv(in_channels, out_channels),
    )
    self.emb_layer = nn.Sequential(    #It is needed to change the dimensions of the time tensor in order to sum it to the output of the pooling layer
        nn.SiLU(),  #Activation function
        nn.Linear(size_time_dimension, out_channels)
    )
  def forward(self, x, time_tensor):
    out1 = self.pool_layer(x)
    out2 = self.emb_layer(time_tensor)[:, :, None, None].repeat(1,1,out1.shape[-2], out1.shape[-1]) #In this line we first match the number of time dimensions wiht the number of channels
                                                                                                    #Then this "[:, :, None, None].repeat(1,1,x.shape[-2], x.shape[-1])" part server to match the size of time tensor to the size of the outputx
    return (out1 + out2)


class US(nn.Module):
  def __init__(self, in_channels, out_channels, size_time_dimension = 256):
    super().__init__()

    self.up_block = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) #Block to upsample the resolution of the input

    self.conv_block = nn.Sequential(    #Convolutional block
        d_conv(in_channels, in_channels, residual=True),
        d_conv(in_channels, out_channels),
    )
  
    self.emb_layer = nn.Sequential(    #It is needed to change the dimensions of the time tensor in order to sum it to the output of the pooling layer
          nn.SiLU(),  #Activation function
          nn.Linear(size_time_dimension, out_channels),
      )
  def forward(self, x, skip_input, time_tensor):
    out1 = self.up_block(x)
    out2 = torch.cat([skip_input, out1], dim=1)
    out3 = self.conv_block(out2)
    out4 = self.emb_layer(time_tensor)[:, :, None, None].repeat(1,1,out3.shape[-2], out3.shape[-1]) #In this line we first match the number of time dimensions wiht the number of channels
                                                                                                    #Then this "[:, :, None, None].repeat(1,1,x.shape[-2], x.shape[-1])" part serves to match the size of time tensor to the size of the outputx
    return out3 + out4

class UNET(nn.Module):
  def __init__(self, args, in_channels = 1 , out_channels = 1, size_time_dimension = 256, number_classes_input = None, dimension_input_image = 32):
    super().__init__()
    self.args = args
    
    self.size_time_dimension = size_time_dimension
    self.dimension_input_image = dimension_input_image
    #Now we start the definition of the layers with the especific dimensions of the channels
    self.in_channel = d_conv(in_channels, 64)
    self.down1 = DS(64, 128,self.size_time_dimension)
    self.sa1 = SA(128, self.dimension_input_image/2)
    self.down2 = DS(128, 256,self.size_time_dimension)
    self.sa2 = SA(256, self.dimension_input_image/4)
    self.down3 = DS(256, 256,self.size_time_dimension)
    self.sa3 = SA(256, self.dimension_input_image/8)

    self.bot1 = d_conv(256, 512)
    self.bot2 = d_conv(512, 512)
    self.bot3 = d_conv(512, 256)

    self.up1 = US(512, 128,self.size_time_dimension)
    self.sa4 = SA(128, self.dimension_input_image/4)
    self.up2 = US(256, 64,self.size_time_dimension)
    self.sa5 = SA(64, self.dimension_input_image/2)
    self.up3 = US(128, 64,self.size_time_dimension)
    self.sa6 = SA(64, self.dimension_input_image)
    self.out_channel = nn.Conv2d(64, out_channels, kernel_size=1)

    if number_classes_input is not None:
      self.classification_embeding = nn.Embedding(number_classes_input, int(size_time_dimension))
  
  def forward(self, x_input, time_tensor, y_input):

        if y_input is not None:
          embeded_labels = self.classification_embeding(y_input)
          time_tensor = time_tensor + embeded_labels


        x1 = self.in_channel(x_input)
        x2 = self.down1(x1, time_tensor)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, time_tensor)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, time_tensor)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x5 = self.up1(x4, x3, time_tensor)
        x5 = self.sa4(x5)
        x5 = self.up2(x5, x2, time_tensor)
        x5 = self.sa5(x5)
        x5 = self.up3(x5, x1, time_tensor)
        x5 = self.sa6(x5)
        output = self.out_channel(x5)
        return output