Hey everybody!
I am currently trying to implement a diffusion model using the UNET architecture in order to generate images based on labels. Since I am very new to the subject I started training it with the MNIST data set, since the images are of simpler kind.
The issue I am facing is the fact that the average loss of my model is always around one, it starts with numbers like 1.002 and continues with numbers around 1.0004. I tried changing the implementation of my UNET, thinking that maybe that was the problem, but I encountered the same problem.
This is the Main file
!pip install import_ipynb
import import_ipynb #This is a package that allows me to get functions directly from colab notebooks
# Commented out IPython magic to ensure Python compatibility.
import os
import torch
import torch.nn as nn
import argparse
import numpy as np
#%cd "/content/drive/MyDrive/Fifth year/ClearBox/Diffusion_model_training"
# %cd "/content/drive/MyDrive/Diffusion_model_training"
from model import UNET
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms import ToPILImage
import torchvision
from torchsummary import summary
from tqdm import tqdm
import sys
from PIL import Image
def sin_time_embeding(time_vector, number_channels = 256):
batch_size = np.shape(time_vector)[0]
emb = torch.zeros([batch_size, number_channels])
seq_channels = torch.arange(number_channels/2).long()
for position, t in enumerate(time_vector):
for i in seq_channels:
value = t/(torch.pow(10000, 2*i/(number_channels)))
sin = (torch.sin(value)).long()
cos = torch.cos(value).long()
emb[position, 2*i] = sin
emb[position, 2*i+1] = cos
return emb
class basics:
def __init__(self, args, number_noise_steps = 1000, beta_start = 1e-4, beta_end = 0.02, image_size = 64, device = "cuda"):
self.number_noise_steps = number_noise_steps
self.beta_start = beta_start
self.beta_end = beta_end
self.image_size = image_size
self.device = device
self.args = args
self.beta = torch.linspace(self.beta_start, self.beta_end, self.number_noise_steps).to(self.device)
self.alpha = 1 - self.beta
self.big_alpha = torch.cumprod(self.alpha, dim = 0)
def produce_noise(self, x0, time_position):
time_position = time_position[:, None, None, None].to(self.args.device)
part1 = (torch.sqrt(self.big_alpha[time_position]) * x0)
part2 = (torch.sqrt(1 - self.big_alpha[time_position]) * torch.randn_like(x0))
part3 = part1 + part2
return part3, torch.randn_like(x0)
def sampling_image(self, model, batch_size, label, classifier_scale = 3): #Labels has to have batch size
print("Start Sampling")
model.eval()
x_noise = torch.randn(batch_size, 1, self.image_size, self.image_size).to(self.args.device)
with torch.no_grad():
for i in reversed(range(1, self.number_noise_steps)):
t = (torch.ones(batch_size) * i).long().to(self.args.device)
if i == 0:
z = torch.zeros(x_noise.size())
else:
z = torch.rand_like(x_noise).to(self.args.device)
alpha_buffer = self.alpha[t][:, None, None, None]
big_alpha_buffer = self.big_alpha[t][:, None, None, None]
beta_buffer = self.beta[t][:, None, None, None]
sinusoidal_time_embeding = sin_time_embeding(t).to(self.args.device)
pred_classified_noise = model(x_noise, sinusoidal_time_embeding, label)
pred_noise = pred_classified_noise
if classifier_scale > 0: #The classifier scale is what defines the intensity of the interpolation towards the classified predicited noise
pred_unclassified_noise = model(x_noise, sinusoidal_time_embeding, None)
pred_interpolated_noise = torch.lerp(pred_unclassified_noise, pred_classified_noise,classifier_scale)
pred_noise = pred_interpolated_noise
part2 = ((1 - alpha_buffer)/(torch.sqrt(1 - big_alpha_buffer))) * pred_noise
xtm = ((1/alpha_buffer) * (x_noise - part2)) + torch.sqrt(beta_buffer) * z
x_noise = xtm
model.train()
return xtm
def train(args, model,dataloader, optmizer, loss, model_checkpoint = None):#Need to take it out of the basics object and create args
basic_obj = basics(args, args.number_noise_steps, args.beta_start, args.beta_end, args.image_size, args.device)
if args.use_checkpoints == "True" and model_checkpoint != None: #Load the checkpoints of the model
print("Using checkpoint")
model.load_state_dict(model_checkpoint['model_state_dict'])
optmizer.load_state_dict(model_checkpoint['optimizer_state_dict'])
epoch = model_checkpoint["epoch"]
else:
epoch = 0
while epoch < args.number_epochs:
print("Epoch: ", epoch)
list_losses = []
for i, data in tqdm(enumerate(dataloader)): #Iterating over the images from the dataloader
optmizer.zero_grad() #Setting gradient to zero after each iteration
label = data[1].to(args.device)
x0 = data[0].to(args.device)
t = torch.randint(1, args.number_noise_steps, (args.batch_size, )).to(args.device) #Getting a vector of time values the size of the bactch
xt_rand, normal_distribution = basic_obj.produce_noise(x0, t) #Generaring the noisy image at the specified time stamps from vector "t"
xt_rand = xt_rand.to(args.device)
normal_distribution = normal_distribution.to(args.device)
sinusoidal_time_embeding = sin_time_embeding(t).to(args.device) #This needs to be done because the UNET only accepts the time tensor when it is trannformed
if torch.rand(1) < 0.1:
label = None
x_pred = model(xt_rand, sinusoidal_time_embeding, label).to(args.device) #Predicted images from the UNET by inputing the image and the time without the sinusoidal embeding
Lsimple = loss(normal_distribution, x_pred).to(args.device)
list_losses.append(Lsimple.item())
Lsimple.backward()
optmizer.step()
epoch += 1
#Saving Checkpoint
try:
EPOCH = epoch
PATH = args.checkpoint_directory + "/DiffusionModel.pt"
torch.save({
'epoch': EPOCH,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optmizer.state_dict(),
'LRSeg': args.learning_rate,
}, PATH)
print("checkpoint saved")
except:
pass
labels_to_predict = torch.tensor(5).to(args.device)
image_sample = basic_obj.sampling_image(model, 2, labels_to_predict)
image_sample1 = torch.squeeze(image_sample[0])
trsmr = ToPILImage()
img_pil1 = trsmr(image_sample1)
display(img_pil1)
print("The average loss was: ", np.mean(list_losses))
def main(params):
parser = argparse.ArgumentParser(description='Diffusion model')
parser.add_argument('--device', type=str, default="cuda", help='Device to run the code on')
parser.add_argument('--use_checkpoints', type=str, default="True", help='Use checkpoints')
parser.add_argument('--emb_dimension', type=int, default=256, help='Number of embeded time dimension')
parser.add_argument('--number_noise_steps', type=int, default=1000, help='Numbe of steps required to noise the image')
parser.add_argument('--beta_start', type=float, default=1e-4, help='First value of beta')
parser.add_argument('--beta_end', type=float, default=0.02, help='Last value of beta')
parser.add_argument('--beta_curve', type=str, default="linear", help='How the value of beta will change over time')
parser.add_argument('--image_size', type=int, default=32, help='Size of the squared input image')
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
parser.add_argument('--number_workers', type=int, default=2, help='Number of workers for the dataloader')
parser.add_argument('--number_steps', type=int, default=200, help='How many iterations steps the model will learn from')
parser.add_argument('--number_epochs', type=int, default=10, help='Number of epochs the model will learn from')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Initial learning rate of the optmizer')
parser.add_argument('--number_classes', type=int, default=10, help='Number of classes for the classifier')
parser.add_argument('--checkpoint_directory', \
type=str, default="/content/drive/MyDrive/Fifth year/ClearBox/Diffusion_model_training/checkpoints", help='')
args = parser.parse_args(params)
#Import the Mninst dataset for training and validation
dataset_train = MNIST("/content/MNIST_train", download=True, train=True,transform=torchvision.transforms.Compose([torchvision.transforms.Resize(32), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(0.5, 0.5)]),)
dataloader_train = DataLoader(dataset_train,args.batch_size, num_workers=args.number_workers, shuffle=True, drop_last=True)
args = parser.parse_args(params)
diffusion_model = UNET(args, 1,1,number_classes_input=args.number_classes).to(args.device)
optmizer = torch.optim.Adam(diffusion_model.parameters(), lr=args.learning_rate, betas=(0.9, 0.99))
loss_mse = nn.MSELoss()
#Doing the calculation for the number of iterations
size_iterations = len(dataloader_train.dataset)/args.batch_size
params_update = ['--number_steps', str(int(size_iterations))]
params = params + params_update
args = parser.parse_args(params)
print("Number iterations: ", size_iterations)
train(args, diffusion_model, dataloader_train, optmizer, loss_mse)
#Needs to define a lot of things still, dataset class, dataloader, model
if __name__ == "__main__":
main(["--device", "cuda",
"--batch_size", "32"])
This is the model
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import sys
from torch.nn import ZeroPad2d
#Predefined blocks
#Double convolution with Gelu activation
#The size of the input is going to be size of the output, except for the number of channels
class d_conv(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels = False, residual = False): #Residuals define if there is going to be a residual connection from the input to the output
super().__init__()
self.residual = residual
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(1, mid_channels),
nn.GELU(),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.GroupNorm(1, out_channels),)
def forward(self, x):
if self.residual:
output = F.gelu(x + self.double_conv(x))
else:
output = self.double_conv(x)
return output
#Define the self attention block
class SA(nn.Module): #Standard attention block
def __init__(self, in_channels, dimension_input_image):
super().__init__()
self.in_channels = in_channels
self.dimension_input_image = int(dimension_input_image)
self.mha = nn.MultiheadAttention(in_channels, 4, batch_first=True)
self.layer_norm = nn.LayerNorm([in_channels])
self.feed_forward = nn.Sequential(
nn.LayerNorm([in_channels]),
nn.Linear(in_channels, in_channels),
nn.GELU(),
nn.Linear(in_channels, in_channels),
)
def forward(self, x):
x = x.view(-1, self.in_channels, int(self.dimension_input_image**2)).swapaxes(1,2)
norm_x = self.layer_norm(x)
attention_value, _ = self.mha(norm_x, norm_x, norm_x)
attention_value = attention_value + x
attention_value = self.feed_forward(attention_value) + attention_value
output = attention_value.swapaxes(2, 1).view(-1, self.in_channels, self.dimension_input_image, self.dimension_input_image)
return output
#Defining Downsaple, midsample, and upsample blocks
class DS(nn.Module):
def __init__(self, in_channels, out_channels, size_time_dimension = 256):
super().__init__()
self.pool_layer = nn.Sequential( #Pooling layer with some convolutional blocks to change the number of channels and add the resitual
nn.MaxPool2d(2),
d_conv(in_channels, in_channels, residual=True),
d_conv(in_channels, out_channels),
)
self.emb_layer = nn.Sequential( #It is needed to change the dimensions of the time tensor in order to sum it to the output of the pooling layer
nn.SiLU(), #Activation function
nn.Linear(size_time_dimension, out_channels)
)
def forward(self, x, time_tensor):
out1 = self.pool_layer(x)
out2 = self.emb_layer(time_tensor)[:, :, None, None].repeat(1,1,out1.shape[-2], out1.shape[-1]) #In this line we first match the number of time dimensions wiht the number of channels
#Then this "[:, :, None, None].repeat(1,1,x.shape[-2], x.shape[-1])" part server to match the size of time tensor to the size of the outputx
return (out1 + out2)
class US(nn.Module):
def __init__(self, in_channels, out_channels, size_time_dimension = 256):
super().__init__()
self.up_block = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) #Block to upsample the resolution of the input
self.conv_block = nn.Sequential( #Convolutional block
d_conv(in_channels, in_channels, residual=True),
d_conv(in_channels, out_channels),
)
self.emb_layer = nn.Sequential( #It is needed to change the dimensions of the time tensor in order to sum it to the output of the pooling layer
nn.SiLU(), #Activation function
nn.Linear(size_time_dimension, out_channels),
)
def forward(self, x, skip_input, time_tensor):
out1 = self.up_block(x)
out2 = torch.cat([skip_input, out1], dim=1)
out3 = self.conv_block(out2)
out4 = self.emb_layer(time_tensor)[:, :, None, None].repeat(1,1,out3.shape[-2], out3.shape[-1]) #In this line we first match the number of time dimensions wiht the number of channels
#Then this "[:, :, None, None].repeat(1,1,x.shape[-2], x.shape[-1])" part serves to match the size of time tensor to the size of the outputx
return out3 + out4
class UNET(nn.Module):
def __init__(self, args, in_channels = 1 , out_channels = 1, size_time_dimension = 256, number_classes_input = None, dimension_input_image = 32):
super().__init__()
self.args = args
self.size_time_dimension = size_time_dimension
self.dimension_input_image = dimension_input_image
#Now we start the definition of the layers with the especific dimensions of the channels
self.in_channel = d_conv(in_channels, 64)
self.down1 = DS(64, 128,self.size_time_dimension)
self.sa1 = SA(128, self.dimension_input_image/2)
self.down2 = DS(128, 256,self.size_time_dimension)
self.sa2 = SA(256, self.dimension_input_image/4)
self.down3 = DS(256, 256,self.size_time_dimension)
self.sa3 = SA(256, self.dimension_input_image/8)
self.bot1 = d_conv(256, 512)
self.bot2 = d_conv(512, 512)
self.bot3 = d_conv(512, 256)
self.up1 = US(512, 128,self.size_time_dimension)
self.sa4 = SA(128, self.dimension_input_image/4)
self.up2 = US(256, 64,self.size_time_dimension)
self.sa5 = SA(64, self.dimension_input_image/2)
self.up3 = US(128, 64,self.size_time_dimension)
self.sa6 = SA(64, self.dimension_input_image)
self.out_channel = nn.Conv2d(64, out_channels, kernel_size=1)
if number_classes_input is not None:
self.classification_embeding = nn.Embedding(number_classes_input, int(size_time_dimension))
def forward(self, x_input, time_tensor, y_input):
if y_input is not None:
embeded_labels = self.classification_embeding(y_input)
time_tensor = time_tensor + embeded_labels
x1 = self.in_channel(x_input)
x2 = self.down1(x1, time_tensor)
x2 = self.sa1(x2)
x3 = self.down2(x2, time_tensor)
x3 = self.sa2(x3)
x4 = self.down3(x3, time_tensor)
x4 = self.sa3(x4)
x4 = self.bot1(x4)
x4 = self.bot2(x4)
x4 = self.bot3(x4)
x5 = self.up1(x4, x3, time_tensor)
x5 = self.sa4(x5)
x5 = self.up2(x5, x2, time_tensor)
x5 = self.sa5(x5)
x5 = self.up3(x5, x1, time_tensor)
x5 = self.sa6(x5)
output = self.out_channel(x5)
return output