I am trying to train a CNN with PyTorch in Google Colab, but after around 170 batches Colab freezes because all available RAM (12.5 GB) is used up. I've looked this up online, and it seems that other people with this issue are often unintentionally storing the computation graph (e.g. of the loss). Checking my code, that doesn't seem to be the case, so I was hoping someone could confirm whether it is and point out what the problem in my code is. I am using the following training script with a custom dataloader:
import torch
from torch.utils.data import DataLoader
from torch import optim, nn
from torchsummary import summary
from dataset import AudioDataset
from model import ConvNet
import os
import time


def train(x, y, model, optimizer, loss_function, lr_scheduler=None):
    model.train()
    y_pred = model(x)
    loss = loss_function(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if lr_scheduler:
        lr_scheduler.step(optimizer)
    return loss.item()


def validate(x, y, model, loss_function):
    model.eval()
    y_pred = model(x)
    loss = loss_function(y_pred, y)
    return loss.item()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    config = {
        "learning_rate": 0.02,
        "batch_size": 32,
        "n_epochs": 30,
    }

    train_loader = DataLoader(
        AudioDataset(os.path.join("audio", "train")), batch_size=config["batch_size"], shuffle=True
    )
    val_loader = DataLoader(
        AudioDataset(os.path.join("audio", "test")), batch_size=config["batch_size"], shuffle=True
    )

    model = ConvNet()
    summary(model, input_size=(1, 173, 64))

    loss_fn = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=config["learning_rate"])
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

    model.to(device)

    # Training loop
    for epoch in range(config["n_epochs"]):
        print(f"Epoch {epoch+1}/{config['n_epochs']}")
        start_time = time.time()
        train_loss = 0
        train_accuracy = 0
        val_loss = 0
        val_accuracy = 0

        for index, batch in enumerate(train_loader):
            if (index + 1) % 10 == 0:
                print(f"Batch {index + 1}/{len(train_loader)}")
            x_train, y_train = batch
            x_train = x_train.to(device)
            y_train = y_train.to(device)
            train_loss += train(x_train, y_train.float(), model, optimizer, loss_fn)

        for index, batch in enumerate(val_loader):
            x_val, y_val = batch
            x_val = x_val.to(device)
            y_val = y_val.to(device)
            val_loss += validate(x_val, y_val.float(), model, loss_fn)

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        total_time = time.time() - start_time
        print(f"{total_time} - loss: {train_loss} - val_loss: {val_loss}")
The amount of memory used seems to suddenly spike and then keep increasing, see this graph: