For some reason my RAM usage is steadily increasing while training a Variational Autoencoder

For some reason while training my VAE my RAM usage is steadily increasing, and I cannot seem to pin point why.

I have narrowed down the problem to my save_plots function by using psutil.virtual_memory() checking my virtual memory between function calls.

Here is the code for the VAE model and initialization of model and training params:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from utils import modelSummary, train_evaluate, plot_training_results


class Encoder(nn.Module):
    def __init__(self, latent_dims) -> None:
        super(Encoder, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 64, 3, stride = 2, bias = False)
        self.batchnorm1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128 , 3, stride = 2, bias = False)
        self.batchnorm2 = nn.BatchNorm2d(128)
        
        self.conv3 = nn.Conv2d(128, 128, 3, stride = 2) # (#num samples, 64 , 2 , 2)
        
        self.flatten = nn.Flatten(start_dim = 1) # (#num samples, 256)
        
        self.linear1 = nn.Linear(512, 1024)
        
        self.mu = nn.Linear(1024, latent_dims)
        self.sigma = nn.Linear(1024, latent_dims)
        
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda()
        self.N.scale = self.N.scale.cuda()
        
        self.kl = 0
        

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.batchnorm1(x)
        x = F.relu(self.conv2(x))
        x = self.batchnorm2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = F.relu(self.linear1(x))
        
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))

        z = mu + sigma * self.N.sample(mu.shape)
        
        self.kl = (sigma**2  + mu**2 - torch.log(sigma) - 0.5).sum()

        return z
    
class Decoder(nn.Module):
    def __init__(self, latent_dims) -> None:
        super(Decoder, self).__init__()

        self.linear1 = nn.Linear(latent_dims, 512)
        
        
        self.deconv1 = nn.ConvTranspose2d(32, 128, 3, stride = 3, padding = 1, output_padding = 2, bias = False)
        self.batchnorm1 = nn.BatchNorm2d(128)
        
        self.deconv2 = nn.ConvTranspose2d(128, 64, 3, stride = 2, output_padding = 1, bias = False)
        self.batchnorm2 = nn.BatchNorm2d(64)
        
        self.deconv3 = nn.ConvTranspose2d(64, 1, 3)
        
    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = x.view(-1, 32, 4, 4)
        
        x = F.relu(self.deconv1(x))
        x = self.batchnorm1(x)
        
        x = F.relu(self.deconv2(x))
        x = self.batchnorm2(x)
        
        x = torch.sigmoid(self.deconv3(x))
        
        return x
    
class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_dims) -> None:
        super(VariationalAutoEncoder, self).__init__()    
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)
    
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
        


if __name__ == '__main__':
    
    # Initialize Model
    latent_dims = 256
    model = VariationalAutoEncoder(latent_dims)
    
    modelSummary(model)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")
    
    training_params = {
        'num_epochs': 200,
        'batch_size': 512,
        'loss_function':F.mse_loss,
        'optimizer': torch.optim.Adam(model.parameters(), lr=1e-4),
        'save_path': 'training_256',
        'sample_size': 10,
        'plot_every': 1,
        'latent_dims' : latent_dims
    }
    
    
    # Load Data
    train_dataset = DataLoader(torchvision.datasets.MNIST(root = './data', train = True, download = True, transform = torchvision.transforms.ToTensor()), batch_size = training_params['batch_size'])
    validation_dataset = DataLoader(torchvision.datasets.MNIST(root = './data', train = False, download = True, transform = torchvision.transforms.ToTensor()), batch_size = training_params['batch_size'])
    
        

    
    metrics = {
        'l1': lambda output, target: (torch.abs(output - target).sum())
    }
    
    train_results, evaluation_results = train_evaluate(model, device, train_dataset, validation_dataset, training_params, metrics)
    plot_training_results(train_results=train_results, validation_results=evaluation_results, training_params=training_params, metrics=metrics)

Here is my utils.py file containing the training loop and other utility functions

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import time
import gc
import numpy as np

import matplotlib.pyplot as plt

def modelSummary(model, verbose=False):
    if verbose:
        print(model)
    
    total_parameters = 0
        
    for name, param in model.named_parameters():
        num_params = param.size()[0]
        total_parameters += num_params
        if verbose:
            print(f"Layer: {name}")
            print(f"\tNumber of parameters: {num_params}")
            print(f"\tShape: {param.shape}")
    
    if total_parameters > 1e5:
        print(f"Total number of parameters: {total_parameters/1e6:.2f}M")
    else:
        print(f"Total number of parameters: {total_parameters/1e3:.2f}K") 

def train_epoch(model: nn.Module, device: torch.device, train_dataloader: DataLoader, training_params: dict, metrics: dict):
    """_summary_

    Args:
        model (nn.Module): Model to be trained by
        device (str): device to be trained on
        train_dataloader (nn.data.DataLoader): Dataloader object to load batches of dataset
        training_params (dict): Dictionary of training parameters containing "batch_size", "loss_function"
                                "optimizer".
        metrics (dict): Dictionary of functional methods that would compute the metric value

    Returns:
        run_results (dict): Dictionary of metrics computed for the epoch
    """
    OPTIMIZER = training_params["optimizer"]
    
    model = model.to(device)
    model.train()
    
    # Dictionary holding result of this epoch
    run_results = dict()
    for metric in metrics:
        run_results[metric] = 0.0
    run_results["loss"] = 0.0
    
    # Iterate over batches
    num_batches = 0
    for x, target in train_dataloader:
        num_batches += 1

        # Move tensors to device
        input = x.to(device)
        
        # Forward pass
        output = model(input)
        
        # Compute loss
        loss = ((output - input)**2).sum() + model.encoder.kl
        
        # Backward pass
        OPTIMIZER.zero_grad()
        loss.backward()
        OPTIMIZER.step()
        
        # Update metrics
        run_results["loss"] += loss.detach().item()
        for key, func in metrics.items():
            run_results[key] += func(output, input).detach().item()
            
        # Clean up memory
        del loss
        del input
        del output
        
    for key in run_results:
        run_results[key] /= num_batches
    
    return run_results


def evaluate_epoch(model: nn.Module, device: torch.device, validation_dataloader: DataLoader, training_params: dict, metrics: dict):
    """_summary_

    Args:
        model (nn.Module): model to evaluate
        device (str): device to evaluate on
        validation_dataloader (DataLoader): DataLoader for evaluation
        training_params (dict): Dictionary of training parameters containing "batch_size", "loss_function"
                                "optimizer".
        metrics (dict): Dictionary of functional methods that would compute the metric value

    Returns:
        run_results (dict): Dictionary of metrics computed for the epoch
    """
    model = model.to(device)
    
    # Dictionary holding result of this epoch
    run_results = dict()
    for metric in metrics:
        run_results[metric] = 0.0
    run_results["loss"] = 0.0
    
    # Iterate over batches
    with torch.no_grad():
        model.eval()
        num_batches = 0
        
        for x, target in validation_dataloader:
            num_batches += 1
            
            
            
            # Move tensors to device
            input = x.to(device)
            target = target.to(device)
            
            # Forward pass
            output = model(input)
            
            # Compute loss
            loss = ((output - input)**2).sum() + model.encoder.kl
            
            # Update metrics
            run_results["loss"] += loss.detach().item()
            for key, func in metrics.items():
                run_results[key] += func(output, input).detach().item()
                
            # Clean up memory
            del loss
            del input
            del output
                
    for key in run_results:
        run_results[key] /= num_batches
        
    return run_results

def train_evaluate(model: nn.Module, device: torch.device, train_dataloader: DataLoader, validation_dataloader: DataLoader, training_params: dict, metrics: dict):
    """Function to train a model and provide statistics during training

    Args:
        model (nn.Module): Model to be trained
        device (torch.device): Device to be trained on
        train_dataset (DataLoader): Dataset to be trained on
        validation_dataset (DataLoader): Dataset to be evaluated on
        training_params (dict): Dictionary of training parameters containing "num_epochs", "batch_size", "loss_function",
                                                                             "save_path", "optimizer"
        metrics (dict): Dictionary of functional methods that would compute the metric value

    Returns:
        _type_: _description_
    """
    NUM_EPOCHS = training_params["num_epochs"]
    BATCH_SIZE = training_params["batch_size"]
    SAVE_PATH = training_params["save_path"]
    SAMPLE_SIZE = training_params["sample_size"]
    PLOT_EVERY = training_params["plot_every"]
    LATENT_DIMS = training_params["latent_dims"]
    
    # Initialize metrics
    train_results = dict()
    train_results['loss'] = np.empty(1)
    evaluation_results = dict()
    evaluation_results['loss'] = np.empty(1)
    
    for metric in metrics:
        train_results[metric] = np.empty(1)
        evaluation_results[metric] = np.empty(1)
    
    batch = next(iter(validation_dataloader))
    idxs = []
    for i in range(SAMPLE_SIZE):
        idx = torch.where(batch[1] == i)[0].squeeze()[0]
        idxs.append(idx.item())
    
    FIXED_SAMPLES = batch[0][idxs].to(device).detach()
   
    FIXED_NOISE = torch.normal(0, 1, size = (100, LATENT_DIMS)).to(device).detach()
    
    del idxs
    del batch
    
    for epoch in range(NUM_EPOCHS):
        start = time.time()
        
        print(f"======== Epoch {epoch+1}/{NUM_EPOCHS} ========")

        # Train Model
        print("Training ... ")
        epoch_train_results = train_epoch(model, device, train_dataloader, training_params, metrics)
        

        # Evaluate Model
        print("Evaluating ... ")
        epoch_evaluation_results = evaluate_epoch(model, device, validation_dataloader, training_params, metrics)
        
        for metric in metrics:
            np.append(train_results[metric], epoch_train_results[metric])
            np.append(evaluation_results[metric], epoch_evaluation_results[metric])
            
        
        # Print results of epoch
        print(f"Completed Epoch {epoch+1}/{NUM_EPOCHS} in {(time.time() - start):.2f}s")
        print(f"Train Loss: {epoch_train_results['loss']:.2f} \t Validation Loss: {epoch_evaluation_results['loss']:.2f}")
        
        # Plot results
        if epoch % PLOT_EVERY == 0:
            save_plots(FIXED_SAMPLES, FIXED_NOISE, model, device, epoch, training_params)
        
        print(f"Items cleaned up: {gc.collect()}")
    
    # Save model
    SAVE = f"{SAVE_PATH}_epoch{epoch + 1}.pt"
    torch.save(model.state_dict(), SAVE)
           
    return train_results, evaluation_results

def save_plots(fixed_samples, fixed_noise, model, device, epoch, training_params):
    """Function to save plots of the model

    Args:
        fixed_samples (torch.Tensor): Samples to be plotted
        fixed_noise (torch.Tensor): Noise to be plotted
        model (nn.Module): Model to be tested
        epoch (int): Epoch number
        SAVE_PATH (str): Path to save plots
    """
    SAMPLE_SIZE = training_params["sample_size"]
    SAVE_PATH = training_params["save_path"]
    
    with torch.no_grad():
        model.eval()

        fixed_samples = fixed_samples.to(device)
        fixed_noise = fixed_noise.to(device)

        outputs = model(fixed_samples)
        generated_images = model.decoder(fixed_noise)
        
        fig, ax = plt.subplots(2, SAMPLE_SIZE, figsize=(SAMPLE_SIZE * 5,15))
        for i in range(SAMPLE_SIZE):
            image = fixed_samples[i].detach().cpu().numpy()
            output = outputs[i].detach().cpu().numpy()
            
            ax[0][i].imshow(image.reshape(28,28))
            ax[1][i].imshow(output.reshape(28,28))
            
        plt.savefig(f"{SAVE_PATH}/training_images/epoch{epoch + 1}.png")
        plt.close()
        
        del fig, ax
        del output
        del outputs
        
        _, axs = plt.subplots(10, 10, figsize=(30, 20))
        axs = axs.flatten()
        
        for image, ax in zip(generated_images, axs):
            ax.imshow(image.cpu().numpy().reshape(28, 28))
            ax.axis('off')
            
        plt.savefig(f"{SAVE_PATH}/generated_images/epoch{epoch + 1}.png")
        plt.close()
        
        # Clean up memory
        del generated_images
        del image
        del _, axs

def plot_training_results(train_results, validation_results, training_params, metrics):
    """Function to plot training results

    Args:
        train_results (dict): Dictionary of training results
        validation_results (dict): Dictionary of validation results
    """
    plt.plot(train_results['loss'], label='Training Loss')
    plt.plot(validation_results['loss'], label='Validation Loss')
    for metric in metrics:
        plt.plot(train_results[metric], label=f"Train {metric}")
        plt.plot(validation_results[metric], label=f"Validation {metric}")
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig(f"{training_params['save_path']}_training_results.png")
    plt.show()
       
if __name__ == '__main__':
    pass

Am I doing something wrong while detaching? Or is it a problem with the number of figures I am saving?

On another sidenote, while training the following by running in a terminal and just calling python VAE.py, I run out of memory due to the steady increase as mentioned above, however if I run it on VSCode it seems to clean up my memory as it nears max, is there any documentation of this or am I mistaken?