For some reason while training my VAE my RAM usage is steadily increasing, and I cannot seem to pin point why.
I have narrowed down the problem to my save_plots
function by using psutil.virtual_memory()
checking my virtual memory between function calls.
Here is the code for the VAE model and initialization of model and training params:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from utils import modelSummary, train_evaluate, plot_training_results
class Encoder(nn.Module):
def __init__(self, latent_dims) -> None:
super(Encoder, self).__init__()
self.conv1 = nn.Conv2d(1, 64, 3, stride = 2, bias = False)
self.batchnorm1 = nn.BatchNorm2d(64)
self.conv2 = nn.Conv2d(64, 128 , 3, stride = 2, bias = False)
self.batchnorm2 = nn.BatchNorm2d(128)
self.conv3 = nn.Conv2d(128, 128, 3, stride = 2) # (#num samples, 64 , 2 , 2)
self.flatten = nn.Flatten(start_dim = 1) # (#num samples, 256)
self.linear1 = nn.Linear(512, 1024)
self.mu = nn.Linear(1024, latent_dims)
self.sigma = nn.Linear(1024, latent_dims)
self.N = torch.distributions.Normal(0, 1)
self.N.loc = self.N.loc.cuda()
self.N.scale = self.N.scale.cuda()
self.kl = 0
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.batchnorm1(x)
x = F.relu(self.conv2(x))
x = self.batchnorm2(x)
x = self.conv3(x)
x = self.flatten(x)
x = F.relu(self.linear1(x))
mu = self.mu(x)
sigma = torch.exp(self.sigma(x))
z = mu + sigma * self.N.sample(mu.shape)
self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 0.5).sum()
return z
class Decoder(nn.Module):
def __init__(self, latent_dims) -> None:
super(Decoder, self).__init__()
self.linear1 = nn.Linear(latent_dims, 512)
self.deconv1 = nn.ConvTranspose2d(32, 128, 3, stride = 3, padding = 1, output_padding = 2, bias = False)
self.batchnorm1 = nn.BatchNorm2d(128)
self.deconv2 = nn.ConvTranspose2d(128, 64, 3, stride = 2, output_padding = 1, bias = False)
self.batchnorm2 = nn.BatchNorm2d(64)
self.deconv3 = nn.ConvTranspose2d(64, 1, 3)
def forward(self, x):
x = F.relu(self.linear1(x))
x = x.view(-1, 32, 4, 4)
x = F.relu(self.deconv1(x))
x = self.batchnorm1(x)
x = F.relu(self.deconv2(x))
x = self.batchnorm2(x)
x = torch.sigmoid(self.deconv3(x))
return x
class VariationalAutoEncoder(nn.Module):
def __init__(self, latent_dims) -> None:
super(VariationalAutoEncoder, self).__init__()
self.encoder = Encoder(latent_dims)
self.decoder = Decoder(latent_dims)
def forward(self, x):
z = self.encoder(x)
return self.decoder(z)
if __name__ == '__main__':
# Initialize Model
latent_dims = 256
model = VariationalAutoEncoder(latent_dims)
modelSummary(model)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")
training_params = {
'num_epochs': 200,
'batch_size': 512,
'loss_function':F.mse_loss,
'optimizer': torch.optim.Adam(model.parameters(), lr=1e-4),
'save_path': 'training_256',
'sample_size': 10,
'plot_every': 1,
'latent_dims' : latent_dims
}
# Load Data
train_dataset = DataLoader(torchvision.datasets.MNIST(root = './data', train = True, download = True, transform = torchvision.transforms.ToTensor()), batch_size = training_params['batch_size'])
validation_dataset = DataLoader(torchvision.datasets.MNIST(root = './data', train = False, download = True, transform = torchvision.transforms.ToTensor()), batch_size = training_params['batch_size'])
metrics = {
'l1': lambda output, target: (torch.abs(output - target).sum())
}
train_results, evaluation_results = train_evaluate(model, device, train_dataset, validation_dataset, training_params, metrics)
plot_training_results(train_results=train_results, validation_results=evaluation_results, training_params=training_params, metrics=metrics)
Here is my utils.py
file containing the training loop and other utility functions
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import time
import gc
import numpy as np
import matplotlib.pyplot as plt
def modelSummary(model, verbose=False):
if verbose:
print(model)
total_parameters = 0
for name, param in model.named_parameters():
num_params = param.size()[0]
total_parameters += num_params
if verbose:
print(f"Layer: {name}")
print(f"\tNumber of parameters: {num_params}")
print(f"\tShape: {param.shape}")
if total_parameters > 1e5:
print(f"Total number of parameters: {total_parameters/1e6:.2f}M")
else:
print(f"Total number of parameters: {total_parameters/1e3:.2f}K")
def train_epoch(model: nn.Module, device: torch.device, train_dataloader: DataLoader, training_params: dict, metrics: dict):
"""_summary_
Args:
model (nn.Module): Model to be trained by
device (str): device to be trained on
train_dataloader (nn.data.DataLoader): Dataloader object to load batches of dataset
training_params (dict): Dictionary of training parameters containing "batch_size", "loss_function"
"optimizer".
metrics (dict): Dictionary of functional methods that would compute the metric value
Returns:
run_results (dict): Dictionary of metrics computed for the epoch
"""
OPTIMIZER = training_params["optimizer"]
model = model.to(device)
model.train()
# Dictionary holding result of this epoch
run_results = dict()
for metric in metrics:
run_results[metric] = 0.0
run_results["loss"] = 0.0
# Iterate over batches
num_batches = 0
for x, target in train_dataloader:
num_batches += 1
# Move tensors to device
input = x.to(device)
# Forward pass
output = model(input)
# Compute loss
loss = ((output - input)**2).sum() + model.encoder.kl
# Backward pass
OPTIMIZER.zero_grad()
loss.backward()
OPTIMIZER.step()
# Update metrics
run_results["loss"] += loss.detach().item()
for key, func in metrics.items():
run_results[key] += func(output, input).detach().item()
# Clean up memory
del loss
del input
del output
for key in run_results:
run_results[key] /= num_batches
return run_results
def evaluate_epoch(model: nn.Module, device: torch.device, validation_dataloader: DataLoader, training_params: dict, metrics: dict):
"""_summary_
Args:
model (nn.Module): model to evaluate
device (str): device to evaluate on
validation_dataloader (DataLoader): DataLoader for evaluation
training_params (dict): Dictionary of training parameters containing "batch_size", "loss_function"
"optimizer".
metrics (dict): Dictionary of functional methods that would compute the metric value
Returns:
run_results (dict): Dictionary of metrics computed for the epoch
"""
model = model.to(device)
# Dictionary holding result of this epoch
run_results = dict()
for metric in metrics:
run_results[metric] = 0.0
run_results["loss"] = 0.0
# Iterate over batches
with torch.no_grad():
model.eval()
num_batches = 0
for x, target in validation_dataloader:
num_batches += 1
# Move tensors to device
input = x.to(device)
target = target.to(device)
# Forward pass
output = model(input)
# Compute loss
loss = ((output - input)**2).sum() + model.encoder.kl
# Update metrics
run_results["loss"] += loss.detach().item()
for key, func in metrics.items():
run_results[key] += func(output, input).detach().item()
# Clean up memory
del loss
del input
del output
for key in run_results:
run_results[key] /= num_batches
return run_results
def train_evaluate(model: nn.Module, device: torch.device, train_dataloader: DataLoader, validation_dataloader: DataLoader, training_params: dict, metrics: dict):
"""Function to train a model and provide statistics during training
Args:
model (nn.Module): Model to be trained
device (torch.device): Device to be trained on
train_dataset (DataLoader): Dataset to be trained on
validation_dataset (DataLoader): Dataset to be evaluated on
training_params (dict): Dictionary of training parameters containing "num_epochs", "batch_size", "loss_function",
"save_path", "optimizer"
metrics (dict): Dictionary of functional methods that would compute the metric value
Returns:
_type_: _description_
"""
NUM_EPOCHS = training_params["num_epochs"]
BATCH_SIZE = training_params["batch_size"]
SAVE_PATH = training_params["save_path"]
SAMPLE_SIZE = training_params["sample_size"]
PLOT_EVERY = training_params["plot_every"]
LATENT_DIMS = training_params["latent_dims"]
# Initialize metrics
train_results = dict()
train_results['loss'] = np.empty(1)
evaluation_results = dict()
evaluation_results['loss'] = np.empty(1)
for metric in metrics:
train_results[metric] = np.empty(1)
evaluation_results[metric] = np.empty(1)
batch = next(iter(validation_dataloader))
idxs = []
for i in range(SAMPLE_SIZE):
idx = torch.where(batch[1] == i)[0].squeeze()[0]
idxs.append(idx.item())
FIXED_SAMPLES = batch[0][idxs].to(device).detach()
FIXED_NOISE = torch.normal(0, 1, size = (100, LATENT_DIMS)).to(device).detach()
del idxs
del batch
for epoch in range(NUM_EPOCHS):
start = time.time()
print(f"======== Epoch {epoch+1}/{NUM_EPOCHS} ========")
# Train Model
print("Training ... ")
epoch_train_results = train_epoch(model, device, train_dataloader, training_params, metrics)
# Evaluate Model
print("Evaluating ... ")
epoch_evaluation_results = evaluate_epoch(model, device, validation_dataloader, training_params, metrics)
for metric in metrics:
np.append(train_results[metric], epoch_train_results[metric])
np.append(evaluation_results[metric], epoch_evaluation_results[metric])
# Print results of epoch
print(f"Completed Epoch {epoch+1}/{NUM_EPOCHS} in {(time.time() - start):.2f}s")
print(f"Train Loss: {epoch_train_results['loss']:.2f} \t Validation Loss: {epoch_evaluation_results['loss']:.2f}")
# Plot results
if epoch % PLOT_EVERY == 0:
save_plots(FIXED_SAMPLES, FIXED_NOISE, model, device, epoch, training_params)
print(f"Items cleaned up: {gc.collect()}")
# Save model
SAVE = f"{SAVE_PATH}_epoch{epoch + 1}.pt"
torch.save(model.state_dict(), SAVE)
return train_results, evaluation_results
def save_plots(fixed_samples, fixed_noise, model, device, epoch, training_params):
"""Function to save plots of the model
Args:
fixed_samples (torch.Tensor): Samples to be plotted
fixed_noise (torch.Tensor): Noise to be plotted
model (nn.Module): Model to be tested
epoch (int): Epoch number
SAVE_PATH (str): Path to save plots
"""
SAMPLE_SIZE = training_params["sample_size"]
SAVE_PATH = training_params["save_path"]
with torch.no_grad():
model.eval()
fixed_samples = fixed_samples.to(device)
fixed_noise = fixed_noise.to(device)
outputs = model(fixed_samples)
generated_images = model.decoder(fixed_noise)
fig, ax = plt.subplots(2, SAMPLE_SIZE, figsize=(SAMPLE_SIZE * 5,15))
for i in range(SAMPLE_SIZE):
image = fixed_samples[i].detach().cpu().numpy()
output = outputs[i].detach().cpu().numpy()
ax[0][i].imshow(image.reshape(28,28))
ax[1][i].imshow(output.reshape(28,28))
plt.savefig(f"{SAVE_PATH}/training_images/epoch{epoch + 1}.png")
plt.close()
del fig, ax
del output
del outputs
_, axs = plt.subplots(10, 10, figsize=(30, 20))
axs = axs.flatten()
for image, ax in zip(generated_images, axs):
ax.imshow(image.cpu().numpy().reshape(28, 28))
ax.axis('off')
plt.savefig(f"{SAVE_PATH}/generated_images/epoch{epoch + 1}.png")
plt.close()
# Clean up memory
del generated_images
del image
del _, axs
def plot_training_results(train_results, validation_results, training_params, metrics):
"""Function to plot training results
Args:
train_results (dict): Dictionary of training results
validation_results (dict): Dictionary of validation results
"""
plt.plot(train_results['loss'], label='Training Loss')
plt.plot(validation_results['loss'], label='Validation Loss')
for metric in metrics:
plt.plot(train_results[metric], label=f"Train {metric}")
plt.plot(validation_results[metric], label=f"Validation {metric}")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig(f"{training_params['save_path']}_training_results.png")
plt.show()
if __name__ == '__main__':
pass
Am I doing something wrong while detaching? Or is it a problem with the number of figures I am saving?
On another sidenote, while training the following by running in a terminal and just calling python VAE.py
, I run out of memory due to the steady increase as mentioned above, however if I run it on VSCode it seems to clean up my memory as it nears max, is there any documentation of this or am I mistaken?