Google Colab running out of ram

I was wondering why my Google Colab is running out of ram when generating an output for my model. Here is the source code

import os
import tarfile
import torch
from torch.utils.data import random_split
import torchvision.transforms as tt
from torchvision.transforms import Compose
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
import shutil
import glob
from random import sample
from torch.utils.data import DataLoader

download_url(url="https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz", root=".")
file = tarfile.open("/content/cifar10.tgz", mode="r")
file.extractall("./data")

data_training = "/content/data/cifar10/train"

if os.path.exists("/content/data/cifar10/validate") is False:
    os.makedirs("/content/data/cifar10/validate")

    os.makedirs("/content/data/cifar10/validate/airplane")

    os.makedirs("/content/data/cifar10/validate/automobile")

    os.makedirs("/content/data/cifar10/validate/bird")

    os.makedirs("/content/data/cifar10/validate/cat")

    os.makedirs("/content/data/cifar10/validate/deer")

    os.makedirs("/content/data/cifar10/validate/dog")

    os.makedirs("/content/data/cifar10/validate/frog")

    os.makedirs("/content/data/cifar10/validate/horse")

    os.makedirs("/content/data/cifar10/validate/ship")

    os.makedirs("/content/data/cifar10/validate/truck")

for i in sample(glob.glob("/content/data/cifar10/train/airplane/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/airplane")

for i in sample(glob.glob("/content/data/cifar10/train/automobile/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/automobile")

for i in sample(glob.glob("/content/data/cifar10/train/bird/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/bird")

for i in sample(glob.glob("/content/data/cifar10/train/cat/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/cat")

for i in sample(glob.glob("/content/data/cifar10/train/deer/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/deer")

for i in sample(glob.glob("/content/data/cifar10/train/dog/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/dog")

for i in sample(glob.glob("/content/data/cifar10/train/frog/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/frog")

for i in sample(glob.glob("/content/data/cifar10/train/horse/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/horse")

for i in sample(glob.glob("/content/data/cifar10/train/ship/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/ship")

for i in sample(glob.glob("/content/data/cifar10/train/truck/*.png"), 500):
    shutil.move(i, "/content/data/cifar10/validate/truck")

train_data_transformation = tt.Compose([
    tt.RandomCrop(32, padding=6, padding_mode="reflect"),
    tt.RandomPerspective(distortion_scale=0.5, p=0.5),
    tt.ToTensor(),
    tt.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010), inplace=True)
])
val_data_transformation = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010), inplace=True)
])

train_ds = ImageFolder(root="/content/data/cifar10/train", transform=train_data_transformation)
val_ds = ImageFolder(root="/content/data/cifar10/validate", transform=val_data_transformation)

train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size=128, num_workers=4, pin_memory=True)

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


def accuracy(predictions, labels):
    preds = torch.max(predictions, dim=1)
    return torch.sum(preds == labels).item() / len(preds)


def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


def training_step(train_input, model):
    image, labels = train_input
    output = model(image)
    loss = F.cross_entropy(output, labels)
    calc_loss = loss(output, labels)
    return calc_loss


def validation_loss(val_inputs, model):
    image, labels = val_inputs
    output = model(image)
    loss = F.cross_entropy(output, labels)
    acc = accuracy(output, labels)
    return {"loss": loss, "Accuracy": acc}


def validation_combine_loss(outputs, model):
    loss_accuracy = [validation_loss(batch, model) for batch in outputs]
    extract_loss = [x["loss"] for x in loss_accuracy]
    combining_loss = torch.stack(extract_loss).mean()

    extract_accuracy = [x["Accuracy"] for x in loss_accuracy]
    combining_Accuracy = torch.stack(extract_accuracy).mean()

    return {"Loss": combining_loss, "Accuracy": combining_Accuracy}


def epoch_end(result, epoch):
    print("epoch: {[]}, last_lr {},  Epoch_loss:{}, Epoch_accuracy {}, train_loss {}".format(epoch, result["lrs"], result["Loss"],result["Accuracy"],result["train loss"]))
                                                                                             


def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)


class ResNet9(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        self.classifier = nn.Sequential(nn.MaxPool2d(4),
                                        nn.Flatten(),
                                        nn.Dropout(0.2),
                                        nn.Linear(512, num_classes))

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        return out


model = ResNet9(3, 10)
model


def fit(epochs, train_dl, val_dl, model, lr):
    optimizer = torch.optim.Adam(model.parameters(), lr)  # defining the optimizer

    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs,
                                                    steps_per_epoch=len(train_dl))  # learning rate scheduler

    def get_lr():
        for param_group in optimizer.param_groups:  # getting the learning rates of e
            return param_group["lr"]

    history = []

    for epoch in range(epochs):
        model.train()
        train_losss = []
        lrs = []
        for batch in train_dl:
            loss = training_step(train_dl, model)
            train_losss.append(loss)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters, 0.1)
            optimizer.step()
            optimizer.zero_grad()
            lrs.append(get_lr())
            scheduler.step()
        # validation
        results = validation_combine_loss(val_dl, model)
        results["lrs"] = lrs
        results["train loss"] = torch.stack(train_losss).mean()
        string_formatting = epoch_end(results, epoch)
        history.append(string_formatting)
    return history


The entire code you see above work just fine, however upon inputting my validation dataset to test the model out, I receive a message saying “Your session crashed after using all available RAM”.

validation_combine_loss(val_dl, model)

The line of code above is what led to the problem, I have no idea what is wrong

Your validation loop seems to store all losses in a list and is not executed in a with torch.no_grad() context. In this case, each loss tensor will still be attached to the computation graph and you will thus store the entire graph with each loss.
Wrap the validation loop either in the no_grad guard or detach() the loss and accuracy before storing them.

1 Like