VAE implementation

import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeconvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(2, 2), stride=2, padding=0):
        """
        A deconvolutional block for upsampling in the decoder.

        :param in_channels: The number of input channels.
        :param out_channels: The number of output channels.
        :param kernel_size: Kernel size for the transposed convolution, default is (2, 2).
        :param stride: Stride for the transposed convolution, default is 2.
        :param padding: Padding for the transposed convolution, default is 0.
        """
        super(DeconvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.BatchNorm2d(out_channels),
            # nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=1, padding='same'):
        """
        A convolutional block used in the U-Net or similar architectures.
        It consists of two convolutional layers, each followed by a batch normalization and a ReLU activation function.

        :param in_channels: The number of input channels.
        :param out_channels: The number of output channels.
        :param kernel_size: The size of the kernel used in the convolution operation, default is (3,3).
        :param stride: The stride of the convolution operation, default is 1.
        :param padding: The type of padding, default is 'same'.
        """
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class Sampling(nn.Module):
    def forward(self, inputs):
        z_mean, z_log_var = inputs
        batch, dim = z_mean.size()
        epsilon = torch.randn(batch, dim, device=z_mean.device)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

"""old implementation"""
# class VAE(nn.Module):
#     def __init__(self, input_channels, latent_dim):
#         super(VAE, self).__init__()
#         # Encoder
#         self.encoder = nn.Sequential(
#             ConvBlock(input_channels, 64),
#             nn.MaxPool2d(2),  # Downsampling
#             ConvBlock(64, 128),
#             nn.MaxPool2d(2),  # Further downsampling
#             ConvBlock(128, 256),
#             nn.MaxPool2d(2)  # Further downsampling
#         )
#         self.fc_mu = nn.Linear(256 * 32 * 32, latent_dim)
#         self.fc_logvar = nn.Linear(256 * 32 * 32, latent_dim)

#         # Decoder
#         self.decoder_fc = nn.Linear(latent_dim, 256 * 32 * 32)
#         self.decoder = nn.Sequential(
#             DeconvBlock(256, 128),  # First upsampling
#             DeconvBlock(128, 64),
#             DeconvBlock(64, input_channels),  # Output to match input channels
#             # nn.Sigmoid()  # Final activation for normalized output
#         )

#     def reparameterize(self, mu, logvar):
#         std = torch.exp(0.5 * logvar)
#         epsilon = torch.randn_like(std)
#         return mu + epsilon * std

#     def forward(self, x):
#         # Encode
#         encoded = self.encoder(x)
#         encoded = encoded.view(encoded.size(0), -1)  # Flatten
#         mu = self.fc_mu(encoded)
#         logvar = self.fc_logvar(encoded)

#         # Reparameterize
#         z = self.reparameterize(mu, logvar)

#         # Decode
#         decoded = self.decoder_fc(z)
#         decoded = decoded.view(-1, 256, 32,32)  # Reshape to feature map size
#         reconstructed = self.decoder(decoded)

#         return reconstructed, mu, logvar


# class vae_encoder(nn.Module):
    
#     def __init__(self, input_channels,  channel_list, embedding_dim,):
#         # call the parent constructor
#         super(vae_encoder, self).__init__()

#         self.input_channels = input_channels

#         self.channel_list = channel_list

#         self.embedding_dim = embedding_dim

#         self.conv_blocks =  nn.Sequential(*[convblock(self.channel_list[i], self.channel_list[i+1]) 
#                                             if i != len(self.channel_list) -1 
#                                             else None  for i in range(len(self.channel_list))])


"""New implementaion"""
class VAE(nn.Module):
    def __init__(self, input_channels, latent_dim):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            ConvBlock(input_channels, 64, stride=1),
            nn.MaxPool2d(2),  # Downsampling
            ConvBlock(64, 128, stride=1),
            nn.MaxPool2d(2),  # Further downsampling
            ConvBlock(128, 256, stride=1),
            nn.MaxPool2d(2)  # Further downsampling
        )
        self.fc_mu = nn.Linear(256 * 32 * 32, latent_dim)
        self.fc_logvar = nn.Linear(256 * 32 * 32, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, 256 * 32 * 32)
        self.decoder = nn.Sequential(
            DeconvBlock(256, 128),  # First upsampling
            DeconvBlock(128, 64),
            DeconvBlock(64, input_channels),  # Output to match input channels
            # nn.Sigmoid()  # Final activation for normalized output
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def forward(self, x):
        # Encode
        encoded = self.encoder(x)
        encoded = encoded.view(encoded.size(0), -1)  # Flatten
        mu = self.fc_mu(encoded)
        logvar = self.fc_logvar(encoded)

        # Reparameterize
        z = self.reparameterize(mu, logvar)
        
        # Decode
        decoded = self.decoder_fc(z)
        decoded = decoded.view(-1, 256, 32, 32)  # Reshape to feature map size
        decoded = F.relu(self.decoder[0](decoded))
        decoded = F.relu(self.decoder[1](decoded))
        # Keep the final layer linear: CrossEntropyLoss applies softmax internally
        reconstructed = self.decoder[2](decoded)

        return reconstructed, mu, logvar

latent_dim = 16
input_channels = 3
test = torch.rand(1, 3, 256, 256)

vae_model = VAE(input_channels, latent_dim)
reconstructed, mu, logvar = vae_model(test)

# print(reconstructed.shape)  # expected: torch.Size([1, 3, 256, 256])
# plt.imshow(reconstructed.squeeze(0).permute(1, 2, 0).detach().numpy())

device = 'cuda'
optimizer = torch.optim.Adam(vae_model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss(reduction='none')



vae_model.to(device)
class EarlyStopper:
    """
    This class provides an early stopping mechanism to prevent overfitting during model training.
    If the validation loss does not decrease for a specified number of epochs (patience), the training is stopped.
    The model state with the lowest validation loss is saved and can be loaded for future use.
    """

    def __init__(self, model, weights_name, patience=1, min_delta=0):
        """
        Initializes the EarlyStopper.

        Args:
            model (nn.Module): The PyTorch model to be trained.
            weights_name (str): The name of the file where the best model weights will be saved.
            patience (int, optional): The number of epochs to wait for the validation loss to decrease. Defaults to 1.
            min_delta (float, optional): The minimum decrease in validation loss to be considered an improvement. Defaults to 0.
        """

        self.model = model
        self.weights_name = weights_name
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """
        Checks if the validation loss has decreased and saves the model weights if it has.
        If the validation loss has not decreased for a specified number of epochs (patience), the training is stopped.

        Args:
            validation_loss (float): The current validation loss.

        Returns:
            bool: True if the training should be stopped, False otherwise.
        """
        # If the validation loss has decreased
        if validation_loss < self.min_validation_loss:
            # Update the minimum validation loss
            self.min_validation_loss = validation_loss
            # Save the current model state
            torch.save(self.model.state_dict(), self.weights_name + '.pt')
            print(f'Saving the best weights at validation loss: {self.min_validation_loss}\n\n')
            # Reset the counter
            self.counter = 0
        # If the validation loss has not decreased enough
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            # Increment the counter
            self.counter += 1
            # If the counter has reached the patience limit
            if self.counter >= self.patience:
                # Return True to stop the training
                return True
        # If the counter has not reached the patience limit, return False to continue the training
        return False
    

model_weights_name = 'vae_growliflower_basic'
patience = 50

early_stopper = EarlyStopper(model=vae_model, weights_name=model_weights_name, patience=patience)



epochs = 500
early_stop_flag = True


def plot_reconstructions(model, val_dataloader, device, epoch):
    """
    Plot original and reconstructed images from the validation set and save the figure.
    
    Args:
        model: The VAE model
        val_dataloader: Validation data loader
        device: Device to run the model on
        epoch: Current epoch number
    """
    model.eval()
    with torch.no_grad():
        # Get a batch of images
        images, _ = next(iter(val_dataloader))
        images = images.to(device)
        
        # Get reconstructions
        reconstructions, _, _ = model(images)
        soft_reconstructions = torch.softmax(reconstructions, dim=1)
        # Select 2 random indices
        idx = torch.randint(0, images.size(0), (2,))
        
        # Create a figure with 2 rows (original and reconstruction) and 2 columns
        fig, axes = plt.subplots(2, 2, figsize=(10, 10))
        
        for i, index in enumerate(idx):
            # Original image
            orig_img = images[index].cpu().permute(1, 2, 0).numpy()
            # Normalize back to [0,1] range for visualization
            orig_img = (orig_img - orig_img.min()) / (orig_img.max() - orig_img.min())
            axes[0, i].imshow(orig_img)
            axes[0, i].set_title(f'Original {i+1}')
            axes[0, i].axis('off')
            
            # Reconstructed image (softmax keeps channel values in [0, 1])
            recon_img = soft_reconstructions[index].cpu().permute(1, 2, 0).numpy()
            axes[1, i].imshow(recon_img)
            axes[1, i].set_title(f'Reconstructed {i+1}')
            axes[1, i].axis('off')
        
        plt.suptitle(f'Epoch {epoch+1}')
        plt.tight_layout()
        
        # Create a directory to save the images if it doesn't exist
        os.makedirs('reconstruction_images', exist_ok=True)
        
        # Save the figure
        plt.savefig(f'reconstruction_images/reconstruction_epoch_{epoch+1}.png')
        plt.close(fig)  # Close the figure to free up memory


def training_settings(model, epochs, device, optimizer, 
                      criterion, train_dataloader, val_dataloader, weights_name, 
                      early_stop_flag = True, fine_tune=True):
    """
    This function trains a PyTorch model for a specified number of epochs, and evaluates it on a validation set.

    Args:
        model (nn.Module): The PyTorch model to be trained.
        epochs (int): The number of epochs to train the model.
        device (str): The device (cpu or cuda) where the model and data are to be loaded.
        optimizer (torch.optim.Optimizer): The optimization algorithm used to update the model parameters.
        criterion (torch.nn.modules.loss._Loss): The loss function used to evaluate the model.
        train_dataloader (torch.utils.data.DataLoader): The DataLoader for the training data.
        val_dataloader (torch.utils.data.DataLoader): The DataLoader for the validation data.
        weights_name (str): Base file name used when saving weights after each epoch if early stopping is disabled.
        early_stop_flag (bool, optional): If True, early stopping is applied when the validation loss doesn't improve. Defaults to True.
        fine_tune (bool, optional): If True, only selected layers are unfrozen (DenseNet-style fine-tuning; use False for this VAE). Defaults to True.

    Returns:
        tuple: Two lists, the average total training loss and the average total validation loss for each epoch.
    """
    train_rl_loss = []
    train_kl_loss = []
    total_train_loss = []

    val_rl_loss = []
    val_kl_loss = []
    total_val_loss = []
    

    for epoch in range(epochs):
        # Fine-tuning: this branch assumes a DenseNet-style classifier
        # (a 'features' module containing 'denseblock4', plus a 'classifier'
        # head); for the VAE defined above, pass fine_tune=False.
        if fine_tune:
            for name, child in model.named_children():
                if name == 'features':
                    for sub_name, sub_child in child.named_children():
                        for param in sub_child.parameters():
                            param.requires_grad = (sub_name == 'denseblock4')
                else:
                    for param in child.parameters():
                        param.requires_grad = False

            if hasattr(model, 'classifier'):
                for param in model.classifier.parameters():
                    param.requires_grad = True

        model.train()
        
        train_loss = 0.0
        epoch_rl_loss = 0.0
        epoch_kl_loss = 0.0

        for images, _ in train_dataloader:
            images = images.to(device)

            optimizer.zero_grad()
            # Forward pass
            reconstruction, mu, logvar = model(images)
            # Reconstruction loss: per-pixel cross entropy over the channel
            # dimension, averaged over the spatial dimensions and the batch
            rl_loss = torch.mean(torch.mean(criterion(reconstruction.float(), images.float()), dim=(1, 2)))
            # KL divergence of N(mu, sigma^2) from N(0, I):
            # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            kl_loss = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()
            # Backward pass and optimization
            loss = rl_loss + kl_loss
            loss.backward()
            optimizer.step()

            # Accumulate running losses in separate variables so the
            # per-batch loss tensors are not modified in place
            train_loss += loss.item()
            epoch_rl_loss += rl_loss.item()
            epoch_kl_loss += kl_loss.item()

        # Compute average training losses
        train_loss /= len(train_dataloader)
        epoch_rl_loss /= len(train_dataloader)
        epoch_kl_loss /= len(train_dataloader)

        total_train_loss.append(train_loss)
        train_rl_loss.append(epoch_rl_loss)
        train_kl_loss.append(epoch_kl_loss)

        # Validation
        model.eval()
        valid_loss = 0.0
        epoch_rl_loss = 0.0
        epoch_kl_loss = 0.0

        with torch.no_grad():
            for images, _ in val_dataloader:
                images = images.to(device)

                # Forward pass
                reconstruction, mu, logvar = model(images)

                rl_loss = torch.mean(torch.mean(criterion(reconstruction.float(), images.float()), dim=(1, 2)))
                kl_loss = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()

                loss = rl_loss + kl_loss
                # Accumulate running losses
                valid_loss += loss.item()
                epoch_rl_loss += rl_loss.item()
                epoch_kl_loss += kl_loss.item()

        # Compute average validation losses
        valid_loss /= len(val_dataloader)
        epoch_rl_loss /= len(val_dataloader)
        epoch_kl_loss /= len(val_dataloader)

        total_val_loss.append(valid_loss)
        val_rl_loss.append(epoch_rl_loss)
        val_kl_loss.append(epoch_kl_loss)
        
        if (epoch + 1) % 10 == 0:
            plot_reconstructions(model, val_dataloader, device, epoch)

        # Print progress
        print(f"Epoch [{epoch+1}], "
        f"Train Total Loss: {total_train_loss[epoch]:.4f}, "
        f"Train RL Loss: {train_rl_loss[epoch]:.4f}, "
        f"Train KL Loss: {train_kl_loss[epoch]:.4f}, "
        f"Valid Total Loss: {total_val_loss[epoch]:.4f}, "
        f"Valid RL Loss: {val_rl_loss[epoch]:.4f}, "
        f"Valid KL Loss: {val_kl_loss[epoch]:.4f}")
        # Check for early stopping
        if early_stop_flag:
            # If early stopping is enabled, call the early_stop method of the EarlyStopper object
            # If the method returns True (i.e., the validation loss has not decreased for a specified number of epochs), break the training loop
            if early_stopper.early_stop(valid_loss):             
                break
        else:
            # If early stopping is not enabled, save the model state after each epoch
            torch.save(model.state_dict(), weights_name + f'_{epoch}'+'.pt')

    return total_train_loss, total_val_loss
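
A sketch of the training call (assuming train_dataloader and val_dataloader exist; they are not shown here):

# Hypothetical call: the dataloaders are assumed to yield (images, labels)
# batches; fine_tune=False because the DenseNet-style fine-tuning branch
# does not apply to this VAE.
train_losses, val_losses = training_settings(
    vae_model, epochs, device, optimizer, criterion,
    train_dataloader, val_dataloader, model_weights_name,
    early_stop_flag=early_stop_flag, fine_tune=False,
)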

This is my implementation of a VAE, but it is not working properly. Can you tell me what the problem is?

Thanks :heart:


These are smallish comments, I didn’t read all of it.

  • Why is the Deconv's ReLU commented out? You do need non-linearities when decoding, I'd say, besides the BatchNorm.

  • The “volume” of 256*32*32 (262,144) is larger than 256*256*3 (196,608). I'm not sure what the standard practice is, but I was expecting it to be somewhat smaller, maybe 128*16*16 (see the sketch at the end of this post).

The code would benefit from removing what's commented out, and even the function comments, to make it succinct (imho).

  • Also, you didn't say what it is that is failing, from your understanding; so what is it (not learning, errors out, …)?
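
For what it's worth, a rough sketch of what I mean (untested; it reuses your ConvBlock, and the channel counts are just an example):

# One extra downsampling stage so the bottleneck is 128*16*16 = 32,768,
# i.e. smaller than the 256*256*3 = 196,608 input.
encoder = nn.Sequential(
    ConvBlock(3, 32),
    nn.MaxPool2d(2),    # 256 -> 128
    ConvBlock(32, 64),
    nn.MaxPool2d(2),    # 128 -> 64
    ConvBlock(64, 128),
    nn.MaxPool2d(2),    # 64 -> 32
    ConvBlock(128, 128),
    nn.MaxPool2d(2),    # 32 -> 16
)
fc_mu = nn.Linear(128 * 16 * 16, latent_dim)
fc_logvar = nn.Linear(128 * 16 * 16, latent_dim)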

Is there any error report or malfunctioning result? What kind of problems are you seeing?

Thanks a lot for your reply!
I removed the ReLU from the deconv block because I re-apply it in the forward method; this lets the final layer produce linear outputs, since the cross entropy will apply softmax already.
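
To make that concrete, a minimal self-contained sketch of the idea (note: CrossEntropyLoss with probability-style float targets needs PyTorch 1.10+; the three "classes" here are the RGB channels):

import torch
import torch.nn as nn

logits = torch.randn(1, 3, 4, 4)   # linear decoder output, no final activation
targets = torch.rand(1, 3, 4, 4)   # image values in [0, 1] used as soft targets
criterion = nn.CrossEntropyLoss(reduction='none')
loss = criterion(logits, targets)  # softmax over dim=1 is applied internally
print(loss.shape)                  # torch.Size([1, 4, 4]), one value per pixel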


Makes sense, I will recheck.
I have changed the architecture completely.

The generated images are very bad.