Memory increases drastically when running resnet18 on CPU

I am running the following code on CPU with torch 1.7.1. My concern is that the memory usage grows with the number of iterations; it reaches about 40 GB and I cannot complete even a single epoch. Could you please explain whether this is expected, or how I can make it more stable?

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.autograd import Variable
import torchvision.datasets as pydsets
from torchvision.models import resnet18

NUM_FEATURES = {
    'cifar10': 32*32
}

def main():
    model = resnet18()
    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)
    optimizer = torch.optim.Adam(model.parameters())

    tr_losses = []
    step_times = []
    n_iter = 0
    n_runs = 1  # For mean efficiency
    # total_timer.start()  # timer helper not included in this snippet
    for r in range(n_runs):
        for epoch in range(int(epochs)):
            running_loss = 0
            running_samples = 0
            with tqdm(total=len(train_loader.dataset)) as progressbar:
                for i, (samples, labels) in enumerate(train_loader):
                    samples = Variable(samples, requires_grad=True)
                    labels = Variable(labels, requires_grad=False)

                    # closure re-evaluates the model and loss for optimizer.step()
                    def closure(backward=True):
                        if torch.is_grad_enabled() and backward:
                            optimizer.zero_grad()
                        model_outputs = model(samples)
                        cri_loss = criterion(model_outputs, labels)
                        del model_outputs
                        if cri_loss.requires_grad and backward:
                            cri_loss.backward(retain_graph=True, create_graph=True)
                        return cri_loss

                    # iter_timer.start()
                    tr_loss = optimizer.step(closure=closure)

                    # step_times.append(iter_timer.stop(tag='Iteration execution time', verbose=False))
                    batch_loss = tr_loss.item()
                    running_loss += batch_loss * train_loader.batch_size
                    running_samples += train_loader.batch_size

                    n_iter += 1
                    # evaluate on the test set every 100 iterations
                    if n_iter % 100 == 0:
                        with torch.no_grad():
                            correct = 0
                            total = 0
                            for tst_samples, tst_labels in test_loader:
                                tst_samples = Variable(tst_samples)
                                tst_outputs = model(tst_samples)
                                _, predicted = torch.max(tst_outputs.data, 1)
                                total += tst_labels.size(0)
                                correct += (predicted == tst_labels).sum()
                            accuracy = 100 * correct / total
                            print("n_iteration: {}. Tr. Loss: {}. Tst. Accuracy: {}.".format(
                                n_iter, running_loss / running_samples, accuracy))

                    progressbar.update(labels.size(0))

            tr_losses.append(running_loss / len(train_loader.sampler))

    # total_timer.stop(tag='Run execution time', verbose=True)  # timer helper not included in this snippet
    # tr_losses = preprocess(tr_losses)
    import numpy as np
    # plt.semilogy(np.cumsum(step_times), tr_losses)
    plt.plot(tr_losses, 'o-')
    # plt.yscale('log')
    plt.grid()

    plt.show()

if __name__ == '__main__':
    batch_size = 32
    epochs = 5
    depth = 18

    task_name = 'cifar10'

    if 'cifar10' in task_name:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = pydsets.CIFAR10(root='../../../Datasets', train=True, download=False, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        test_dataset = pydsets.CIFAR10(root='../../../Datasets', train=False, download=False, transform=transform_test)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    else:
        raise Exception('Select an available dataset')

    num_features = NUM_FEATURES[task_name]

    main()

Could you explain why retain_graph=True and create_graph=True are used?
I don't see a corresponding code snippet clearing the computation graph, so I would expect to see increased memory usage.
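If you don't actually need higher-order gradients, the closure could just call the default backward(), which frees the graph once the gradients are computed. A rough sketch based on your posted closure (not tested on your setup):

def closure(backward=True):
    if torch.is_grad_enabled() and backward:
        optimizer.zero_grad()
    model_outputs = model(samples)
    cri_loss = criterion(model_outputs, labels)
    del model_outputs
    if cri_loss.requires_grad and backward:
        cri_loss.backward()  # default retain_graph=False frees the graph here
    return cri_loss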

Thanks, @ptrblck, that was it! Isn't optimizer.zero_grad() responsible for cleaning up? Could you please provide a reference so I can get a better idea of what you mean?

No, optimizer.zero_grad() will set the .grad attributes of all passed parameters to zeros (or to None if set_to_none=True is used). The computation graph (and thus the intermediate tensors) will be freed during the backward() operation in the default setup (i.e. if retain_graph=False is used).
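A minimal standalone snippet (separate from your script) illustrating both points:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

loss = model(torch.randn(8, 4)).sum()
loss.backward()    # intermediate graph buffers are freed here by default
# loss.backward()  # a second call would raise an error, since the graph was already freed

optimizer.zero_grad()
print(model.weight.grad)  # tensor of zeros

optimizer.zero_grad(set_to_none=True)
print(model.weight.grad)  # None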