import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class DeconvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=(2, 2), stride=2, padding=0):
"""
A deconvolutional block for upsampling in the decoder.
:param in_channels: The number of input channels.
:param out_channels: The number of output channels.
:param kernel_size: Kernel size for the transposed convolution, default is (2, 2).
:param stride: Stride for the transposed convolution, default is 2.
:param padding: Padding for the transposed convolution, default is 0.
"""
super(DeconvBlock, self).__init__()
self.block = nn.Sequential(
nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
nn.BatchNorm2d(out_channels),
            # nn.ReLU(inplace=True)  # activation is applied manually in VAE.forward instead
)
def forward(self, x):
return self.block(x)
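# A quick shape check for DeconvBlock (a minimal sketch; the 32x32 input below is
# illustrative): with kernel_size=2, stride=2, padding=0 a transposed convolution
# doubles the spatial size, since H_out = (H_in - 1) * stride - 2 * padding + kernel_size.
_up = DeconvBlock(256, 128)
print(_up(torch.rand(1, 256, 32, 32)).shape)  # torch.Size([1, 128, 64, 64])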
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=1, padding='same'):
"""
A convolutional block used in the U-Net or similar architectures.
It consists of two convolutional layers, each followed by a batch normalization and a ReLU activation function.
:param in_channels: The number of input channels.
:param out_channels: The number of output channels.
:param kernel_size: The size of the kernel used in the convolution operation, default is (3,3).
:param stride: The stride of the convolution operation, default is 1.
        :param padding: The type of padding, default is 'same' (PyTorch only supports 'same' with stride=1).
"""
super(ConvBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
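# Quick check for ConvBlock (illustrative sizes): with padding='same' and stride=1
# the spatial size is preserved, so the block only changes the channel count.
_cb = ConvBlock(3, 64)
print(_cb(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 64, 64, 64])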
class Sampling(nn.Module):
    """
    Reparameterization trick: draws z ~ N(z_mean, exp(z_log_var)) as
    z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, I).
    Note: the VAE below uses its own `reparameterize` method instead of this module.
    """
    def forward(self, inputs):
        z_mean, z_log_var = inputs
        batch, dim = z_mean.size()
        epsilon = torch.randn(batch, dim, device=z_mean.device)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon
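# Quick check for Sampling (illustrative shapes): the sample has the same shape as z_mean.
_mu, _logvar = torch.zeros(4, 16), torch.zeros(4, 16)
print(Sampling()((_mu, _logvar)).shape)  # torch.Size([4, 16])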
"""old implementation"""
# class VAE(nn.Module):
# def __init__(self, input_channels, latent_dim):
# super(VAE, self).__init__()
# # Encoder
# self.encoder = nn.Sequential(
# ConvBlock(input_channels, 64),
# nn.MaxPool2d(2), # Downsampling
# ConvBlock(64, 128),
# nn.MaxPool2d(2), # Further downsampling
# ConvBlock(128, 256),
# nn.MaxPool2d(2) # Further downsampling
# )
# self.fc_mu = nn.Linear(256 * 32 * 32, latent_dim)
# self.fc_logvar = nn.Linear(256 * 32 * 32, latent_dim)
# # Decoder
# self.decoder_fc = nn.Linear(latent_dim, 256 * 32 * 32)
# self.decoder = nn.Sequential(
# DeconvBlock(256, 128), # First upsampling
# DeconvBlock(128, 64),
# DeconvBlock(64, input_channels), # Output to match input channels
# # nn.Sigmoid() # Final activation for normalized output
# )
# def reparameterize(self, mu, logvar):
# std = torch.exp(0.5 * logvar)
# epsilon = torch.randn_like(std)
# return mu + epsilon * std
# def forward(self, x):
# # Encode
# encoded = self.encoder(x)
# encoded = encoded.view(encoded.size(0), -1) # Flatten
# mu = self.fc_mu(encoded)
# logvar = self.fc_logvar(encoded)
# # Reparameterize
# z = self.reparameterize(mu, logvar)
# # Decode
# decoded = self.decoder_fc(z)
# decoded = decoded.view(-1, 256, 32,32) # Reshape to feature map size
# reconstructed = self.decoder(decoded)
# return reconstructed, mu, logvar
# class vae_encoder(nn.Module):
# def __init__(self, input_channels, channel_list, embedding_dim,):
# # call the parent constructor
# super(vae_encoder, self).__init__()
# self.input_channels = input_channels
# self.channel_list = channel_list
# self.embedding_dim = embedding_dim
# self.conv_blocks = nn.Sequential(*[convblock(self.channel_list[i], self.channel_list[i+1])
# if i != len(self.channel_list) -1
# else None for i in range(len(self.channel_list))])
"""New implementaion"""
class VAE(nn.Module):
    def __init__(self, input_channels, latent_dim):
        """
        A convolutional VAE. The fully connected layers assume 256x256 inputs
        (three 2x poolings give 256-channel 32x32 feature maps).
        :param input_channels: The number of input (and output) channels.
        :param latent_dim: The dimensionality of the latent space.
        """
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            ConvBlock(input_channels, 64, stride=1),
            nn.MaxPool2d(2),  # 256 -> 128
            ConvBlock(64, 128, stride=1),
            nn.MaxPool2d(2),  # 128 -> 64
            ConvBlock(128, 256, stride=1),
            nn.MaxPool2d(2)   # 64 -> 32
        )
        self.fc_mu = nn.Linear(256 * 32 * 32, latent_dim)
        self.fc_logvar = nn.Linear(256 * 32 * 32, latent_dim)
        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, 256 * 32 * 32)
self.decoder = nn.Sequential(
DeconvBlock(256, 128), # First upsampling
DeconvBlock(128, 64),
DeconvBlock(64, input_channels), # Output to match input channels
# nn.Sigmoid() # Final activation for normalized output
)
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std
def forward(self, x):
# Encode
encoded = self.encoder(x)
encoded = encoded.view(encoded.size(0), -1) # Flatten
mu = self.fc_mu(encoded)
logvar = self.fc_logvar(encoded)
# Reparameterize
z = self.reparameterize(mu, logvar)
        # Decode
        decoded = self.decoder_fc(z)
        decoded = decoded.view(-1, 256, 32, 32)  # Reshape to feature map size
        # Apply each DeconvBlock followed by a ReLU (the blocks themselves have
        # their activation commented out), including on the final output
        decoded = F.relu(self.decoder[0](decoded))
        decoded = F.relu(self.decoder[1](decoded))
        reconstructed = F.relu(self.decoder[2](decoded))
        return reconstructed, mu, logvar
# torch.manual_seed(seed)
# random.seed(seed)
# np.random.seed(seed)
latent_dim = 16
input_channels = 3
test = torch.rand(1, 3, 256, 256)
vae_model = VAE(input_channels, latent_dim)
reconstructed, mu, logvar = vae_model(test)
# print(reconstructed.shape)  # expected: torch.Size([1, 3, 256, 256])
# plt.imshow(reconstructed.squeeze(0).permute(1, 2, 0).detach().numpy())
device = 'cuda' if torch.cuda.is_available() else 'cpu'
optimizer = torch.optim.Adam(vae_model.parameters(), lr=1e-5)
# Note: CrossEntropyLoss treats the channel dimension as class scores, so with
# RGB targets it computes a cross entropy across the 3 channels at each pixel.
criterion = nn.CrossEntropyLoss(reduction='none')
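# For reference, a per-pixel MSE is a common alternative reconstruction loss for
# real-valued RGB images (a sketch only, not the loss used below; swap it in for
# `criterion` if per-channel cross entropy is not intended):
# mse_criterion = nn.MSELoss(reduction='none')
# rl_loss = mse_criterion(reconstruction, images).sum(dim=(1, 2, 3)).mean()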
vae_model.to(device)
class EarlyStopper:
"""
This class provides an early stopping mechanism to prevent overfitting during model training.
If the validation loss does not decrease for a specified number of epochs (patience), the training is stopped.
The model state with the lowest validation loss is saved and can be loaded for future use.
"""
def __init__(self, model, weights_name, patience=1, min_delta=0):
"""
Initializes the EarlyStopper.
Args:
model (nn.Module): The PyTorch model to be trained.
weights_name (str): The name of the file where the best model weights will be saved.
patience (int, optional): The number of epochs to wait for the validation loss to decrease. Defaults to 1.
            min_delta (float, optional): The minimum decrease in validation loss to be considered an improvement. Defaults to 0.
"""
self.model = model
self.weights_name = weights_name
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.min_validation_loss = np.inf
def early_stop(self, validation_loss):
"""
Checks if the validation loss has decreased and saves the model weights if it has.
If the validation loss has not decreased for a specified number of epochs (patience), the training is stopped.
Args:
validation_loss (float): The current validation loss.
Returns:
bool: True if the training should be stopped, False otherwise.
"""
# If the validation loss has decreased
if validation_loss < self.min_validation_loss:
# Update the minimum validation loss
self.min_validation_loss = validation_loss
# Save the current model state
torch.save(self.model.state_dict(), self.weights_name+'.pt')
print(f'Saving the best weights at validation loss: {self.min_validation_loss}\n\n')
# Reset the counter
self.counter = 0
# If the validation loss has not decreased enough
elif validation_loss > (self.min_validation_loss + self.min_delta):
# Increment the counter
self.counter += 1
# If the counter has reached the patience limit
if self.counter >= self.patience:
# Return True to stop the training
return True
# If the counter has not reached the patience limit, return False to continue the training
return False
model_weights_name = 'vae_growliflower_basic'
patience = 50
early_stopper = EarlyStopper(model=vae_model, weights_name=model_weights_name, patience=patience)
epochs = 500
early_stop_flag = True
def plot_reconstructions(model, val_dataloader, device, epoch):
"""
Plot original and reconstructed images from the validation set and save the figure.
Args:
model: The VAE model
val_dataloader: Validation data loader
device: Device to run the model on
epoch: Current epoch number
"""
model.eval()
with torch.no_grad():
# Get a batch of images
images, _ = next(iter(val_dataloader))
images = images.to(device)
# Get reconstructions
reconstructions, _, _ = model(images)
        # Softmax over the channel dim, mirroring the CrossEntropyLoss objective
        soft_reconstructions = torch.softmax(reconstructions, dim=1)
# Select 2 random indices
idx = torch.randint(0, images.size(0), (2,))
# Create a figure with 2 rows (original and reconstruction) and 2 columns
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for i, index in enumerate(idx):
# Original image
orig_img = images[index].cpu().permute(1, 2, 0).numpy()
# Normalize back to [0,1] range for visualization
orig_img = (orig_img - orig_img.min()) / (orig_img.max() - orig_img.min())
axes[0, i].imshow(orig_img)
axes[0, i].set_title(f'Original {i+1}')
axes[0, i].axis('off')
# Reconstructed image
recon_img = soft_reconstructions[index].cpu().permute(1, 2, 0).numpy()
# Normalize back to [0,1] range for visualization
# recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min())
axes[1, i].imshow(recon_img)
axes[1, i].set_title(f'Reconstructed {i+1}')
axes[1, i].axis('off')
plt.suptitle(f'Epoch {epoch+1}')
plt.tight_layout()
# Create a directory to save the images if it doesn't exist
os.makedirs('reconstruction_images', exist_ok=True)
# Save the figure
plt.savefig(f'reconstruction_images/reconstruction_epoch_{epoch+1}.png')
plt.close(fig) # Close the figure to free up memory
def training_settings(model, epochs, device, optimizer,
                      criterion, train_dataloader, val_dataloader, weights_name,
                      early_stop_flag=True, fine_tune=False):
    """
    This function trains a PyTorch model for a specified number of epochs and evaluates it on a validation set.
    Args:
        model (nn.Module): The PyTorch model to be trained.
        epochs (int): The number of epochs to train the model.
        device (str): The device (cpu or cuda) where the model and data are to be loaded.
        optimizer (torch.optim.Optimizer): The optimization algorithm used to update the model parameters.
        criterion (torch.nn.modules.loss._Loss): The loss function used to evaluate the reconstruction.
        train_dataloader (torch.utils.data.DataLoader): The DataLoader for the training data.
        val_dataloader (torch.utils.data.DataLoader): The DataLoader for the validation data.
        weights_name (str): Base file name used when saving weights without early stopping.
        early_stop_flag (bool, optional): If True, early stopping is applied when validation loss doesn't improve. Defaults to True.
        fine_tune (bool, optional): If True, freezes all layers except 'denseblock4' and the classifier
            (only meaningful for DenseNet-style models; leave False for this VAE). Defaults to False.
    Returns:
        tuple: Six lists - the per-epoch training total/reconstruction/KL losses,
        followed by the per-epoch validation total/reconstruction/KL losses.
    """
    train_rl_loss = []
    train_kl_loss = []
    total_train_loss = []
    val_rl_loss = []
    val_kl_loss = []
    total_val_loss = []
    for epoch in range(epochs):
        model.train()
        # Fine tuning: freeze everything except 'denseblock4' and the classifier.
        # This branch only applies to DenseNet-style models that expose
        # `features` and `classifier`; it does not apply to the VAE above.
        if fine_tune:
            for name, child in model.named_children():
                if name == 'features':
                    for sub_name, sub_child in child.named_children():
                        requires_grad = (sub_name == 'denseblock4')
                        for param in sub_child.parameters():
                            param.requires_grad = requires_grad
                else:
                    for param in child.parameters():
                        param.requires_grad = False
            for param in model.classifier.parameters():
                param.requires_grad = True
        train_loss = 0.0
        epoch_rl_loss = 0.0
        epoch_kl_loss = 0.0
        for images, _ in train_dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            # Forward pass
            reconstruction, mu, logvar = model(images)
            # Reconstruction loss: per-pixel criterion averaged over the spatial
            # dims, then over the batch
            rl_loss = torch.mean(torch.mean(criterion(reconstruction.float(), images.float()), dim=(1, 2)))
            # KL divergence of N(mu, sigma^2) from N(0, I):
            # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), averaged over the batch
            kl_loss = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()
            # Backward pass and optimization
            loss = rl_loss + kl_loss
            loss.backward()
            optimizer.step()
            # Update running losses (accumulate into separate variables so the
            # per-batch loss tensors are not overwritten)
            train_loss += loss.item()
            epoch_rl_loss += rl_loss.item()
            epoch_kl_loss += kl_loss.item()
        # Compute average training losses
        train_loss /= len(train_dataloader)
        epoch_rl_loss /= len(train_dataloader)
        epoch_kl_loss /= len(train_dataloader)
        total_train_loss.append(train_loss)
        train_rl_loss.append(epoch_rl_loss)
        train_kl_loss.append(epoch_kl_loss)
        # Validation
        model.eval()
        valid_loss = 0.0
        epoch_rl_loss = 0.0
        epoch_kl_loss = 0.0
        with torch.no_grad():
            for images, _ in val_dataloader:
                images = images.to(device)
                # Forward pass
                reconstruction, mu, logvar = model(images)
                rl_loss = torch.mean(torch.mean(criterion(reconstruction.float(), images.float()), dim=(1, 2)))
                kl_loss = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()
                loss = rl_loss + kl_loss
                # Update running losses
                valid_loss += loss.item()
                epoch_rl_loss += rl_loss.item()
                epoch_kl_loss += kl_loss.item()
        # Compute average validation losses
        valid_loss /= len(val_dataloader)
        epoch_rl_loss /= len(val_dataloader)
        epoch_kl_loss /= len(val_dataloader)
        total_val_loss.append(valid_loss)
        val_rl_loss.append(epoch_rl_loss)
        val_kl_loss.append(epoch_kl_loss)
if (epoch + 1) % 10 == 0:
plot_reconstructions(model, val_dataloader, device, epoch)
# Print progress
print(f"Epoch [{epoch+1}], "
f"Train Total Loss: {total_train_loss[epoch]:.4f}, "
f"Train RL Loss: {train_rl_loss[epoch]:.4f}, "
f"Train KL Loss: {train_kl_loss[epoch]:.4f}, "
f"Valid Total Loss: {total_val_loss[epoch]:.4f}, "
f"Valid RL Loss: {val_rl_loss[epoch]:.4f}, "
f"Valid KL Loss: {val_kl_loss[epoch]:.4f}")
        # Check for early stopping
if early_stop_flag:
# If early stopping is enabled, call the early_stop method of the EarlyStopper object
# If the method returns True (i.e., the validation loss has not decreased for a specified number of epochs), break the training loop
if early_stopper.early_stop(valid_loss):
break
else:
# If early stopping is not enabled, save the model state after each epoch
torch.save(model.state_dict(), weights_name + f'_{epoch}'+'.pt')
    return total_train_loss, train_rl_loss, train_kl_loss, total_val_loss, val_rl_loss, val_kl_loss
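# A hypothetical invocation (a sketch only: `train_dataloader` and `val_dataloader`
# are assumed to be defined elsewhere and are not part of this snippet):
# losses = training_settings(vae_model, epochs, device, optimizer, criterion,
#                            train_dataloader, val_dataloader, model_weights_name,
#                            early_stop_flag=early_stop_flag)
# train_total, train_rl, train_kl, val_total, val_rl, val_kl = losses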
This is my implementation of a VAE, but it is not working properly. Can you tell me what the problem is?
Thanks