Memory usage increases during training

Hi guys, I am new to PyTorch, and I have run into a problem while training a language model with PyTorch on CPU.

I monitor the memory usage of the training program using memory-profiler and cat /proc/<pid>/status | grep Vm. It seems that the RAM isn’t freed after each epoch ends, so memory usage doesn’t settle to a constant level after the first epoch as it should. Eventually, after a few epochs, this leads to an OOM error on the CPU.
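
For reference, the /proc check is roughly equivalent to the small helper below, which can be called between epochs (just a sketch of the monitoring itself, not part of the training script):

import os

# Rough Python equivalent of `cat /proc/<pid>/status | grep Vm` (Linux only):
# print the Vm* lines (VmPeak, VmSize, VmRSS, VmHWM, ...) for this process.
def log_vm_status():
    with open(f"/proc/{os.getpid()}/status") as f:
        for line in f:
            if line.startswith("Vm"):
                print(line.rstrip(), flush=True)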

The source code is included below (the model and dataset should be downloaded automatically if needed):

import torch
import argparse
import os
import logging
import time
from torch import nn
from contextlib import nullcontext
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from transformers import BertTokenizer, BertModel, AdamW

parser = argparse.ArgumentParser(description="PyTorch PERT Example")

parser.add_argument("--batch-size", type=int, default=16, metavar="N",
                    help="input batch size for training (default: 16)")
parser.add_argument("--epochs", type=int, default=1, metavar="N",
                    help="number of epochs to train (default: 10)")
parser.add_argument("--lr", type=float, default=1e-5, metavar="LR",
                    help="learning rate (default: 0.01)")
parser.add_argument("--seed", type=int, default=1, metavar="S",
                    help="random seed (default: 1)")

# Only for test purpose
parser.add_argument("--log-interval", type=int, default=2, metavar="N",
                    help="how many batches to wait before logging training status")

parser.add_argument("--log-path", type=str, default="",
                    help="Path to save logs. Print to StdOut if log-path is not set")

args = parser.parse_args()


device = 'cpu'

# Define a Dataset wrapper, so it is easier to load the different splits of the dataset
class Dataset(torch.utils.data.Dataset):
    # data_type: which split to load
    def __init__(self, data_type):
        self.data = self.load_data(data_type)

    def load_data(self, data_type):
        tmp_dataset = load_dataset(path='seamew/ChnSentiCorp', split=data_type)
        Data = {}
        for idx, line in enumerate(tmp_dataset):
            sample = line
            Data[idx] = sample
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


checkpoint = 'hfl/chinese-pert-base'
tokenizer = BertTokenizer.from_pretrained(checkpoint, model_max_length=512)


# Collate a list of samples into a single batch used for training
def collate_fn(batch_samples):
    batch_text = []
    batch_label = []
    for sample in batch_samples:
        batch_text.append(sample['text'])
        batch_label.append(int(sample['label']))
    # The tokenizer converts the raw text into the input format the model expects
    X = tokenizer(
        batch_text,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    y = torch.tensor(batch_label)
    return X, y


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.bert_encoder = BertModel.from_pretrained(checkpoint)
        self.classifier = nn.Linear(768, 2)

    def forward(self, x):
        bert_output = self.bert_encoder(**x)
        cls_vectors = bert_output.last_hidden_state[:, 0]
        logits = self.classifier(cls_vectors)
        return logits


def train_loop(args, dataloader, model, loss_fn, optimizer, epoch, total_loss):
    # Set to train mode
    model.train()
    optimizer.zero_grad(set_to_none=True)
    enumerator = enumerate(dataloader, start=1)
    for batch, (X, y) in enumerator:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()

        total_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        if batch % args.log_interval == 0:
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                epoch, batch, len(dataloader),
                100. * batch / len(dataloader), loss.item())
            logging.info(msg)

    return total_loss


def main():
    if args.log_path == "":
        logging.basicConfig(
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%SZ",
            level=logging.DEBUG)
    else:
        logging.basicConfig(
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%SZ",
            level=logging.DEBUG,
            filename=args.log_path)

    torch.manual_seed(args.seed)

    # Load the data and dataset
    train_data = Dataset('train')
    print("[INFO]Train data length:", len(train_data.data), flush=True)

    train_dataloader = DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

    print("[INFO]Data get loaded successfully", flush=True)

    model = NeuralNetwork().to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)

    total_loss = 0.

    for t in range(args.epochs):
        print(f"Epoch {t+1}/{args.epochs + 1}\n-------------------------------")
        start = time.perf_counter()
        total_loss = train_loop(
            args, train_dataloader, model, loss_fn, optimizer, t+1, total_loss)
        end = time.perf_counter()
        print(f"Epoch {t+1}/{args.epochs + 1} Elapsed time:",
              end - start, flush=True)

if __name__ == "__main__":
    main()

I have searched some related posts, but most of those issues were caused by appending tensors to lists. However, I cannot find a similar issue in my code.
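
For clarity, the pattern those posts describe looks roughly like the toy example below (not taken from my script): appending the loss tensor itself to a list keeps it and its graph references alive across iterations, whereas appending loss.item() only stores a plain Python float.

import torch
from torch import nn

# Toy reproduction of the leak pattern from those other posts (not my code).
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for _ in range(100):
    X = torch.randn(4, 10)
    y = torch.randint(0, 2, (4,))
    loss = loss_fn(model(X), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    losses.append(loss)           # grows memory: keeps the tensor and graph references alive
    # losses.append(loss.item())  # fine: stores a detached Python float instead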

Thanks in advance!

I don’t see any obvious issues in your code. Are you seeing the memory increase with the latest stable or nightly release? Which version of PyTorch are you using?

Thanks for your reply.

Previously, I was using a self-built PyTorch, which should be version 1.13.0.

However, I have also run the test with 1.12.0.

Let me run the test again and see if I can provide more relevant data.

I modified the train_loop part of the code so that it breaks out of the loop after 50 batches.

def train_loop(args, dataloader, model, loss_fn, optimizer, epoch, total_loss):
    # Set to train mode
    model.train()
    optimizer.zero_grad(set_to_none=True)
    enumerator = enumerate(dataloader, start=1)
    for batch, (X, y) in enumerator:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()

        total_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        if batch % args.log_interval == 0:
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                epoch, batch, len(dataloader),
                100. * batch / len(dataloader), loss.item())
            logging.info(msg)
        if batch > 50:
            break
    return total_loss

I also monitored the memory usage as follows:

    import tracemalloc

    tracemalloc.start()
    for t in range(args.epochs):
        print(f"Epoch {t+1}/{args.epochs + 1}\n-------------------------------")
        start = time.perf_counter()
        total_loss = train_loop(
            args, train_dataloader, model, loss_fn, optimizer, t+1, total_loss)
        end = time.perf_counter()
        print(f"Epoch {t+1}/{args.epochs + 1} Elapsed time:",
              end - start, flush=True)
        print(f"[INFO] memory usage for epoch {t+1}:", flush=True)
        print(tracemalloc.get_traced_memory())

    tracemalloc.stop()

The result is shown below. For reference, tracemalloc.get_traced_memory() returns the current size and peak size of memory blocks traced by the tracemalloc module as a tuple: (current: int, peak: int).


[INFO] memory usage for epoch 1:
(874691, 1506426)


[INFO] memory usage for epoch 2:
(895483, 1557965)


[INFO] memory usage for epoch 3:
(909171, 1607843)


[INFO] memory usage for epoch 4:
(923043, 1617019)


[INFO] memory usage for epoch 5:
(936723, 1653751)
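
If it helps, a further step I could take is to compare tracemalloc snapshots between epochs to see which source lines the growth is attributed to, along these lines (a sketch on top of the loop above, which I have not run yet):

import tracemalloc

tracemalloc.start()
previous = None
total_loss = 0.
for t in range(args.epochs):
    total_loss = train_loop(
        args, train_dataloader, model, loss_fn, optimizer, t + 1, total_loss)

    # Compare this epoch's allocations with the previous epoch's to see
    # which source lines account for the growth.
    snapshot = tracemalloc.take_snapshot()
    if previous is not None:
        for stat in snapshot.compare_to(previous, "lineno")[:10]:
            print(stat, flush=True)
    previous = snapshot
tracemalloc.stop()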

PyTorch version

 python -c "import torch; print(torch.__version__)"
1.12.0+cu102

Profiling a single batch with memory_profiler

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    95  709.438 MiB  709.438 MiB           1   @profile
    96                                         def train_loop(args, dataloader, model, loss_fn, optimizer, epoch, total_loss):
    97                                             # Set to train mode
    98  709.438 MiB    0.000 MiB           1       model.train()
    99  709.438 MiB    0.000 MiB           1       optimizer.zero_grad(set_to_none=True)
   100  709.754 MiB    0.316 MiB           1       enumerator = enumerate(dataloader, start=1)
   101  710.473 MiB    0.719 MiB           1       for batch, (X, y) in enumerator:
   102  710.473 MiB    0.000 MiB           1           X, y = X.to(device), y.to(device)
   103 13665.367 MiB 12954.895 MiB           1           pred = model(X)
   104 13666.074 MiB    0.707 MiB           1           loss = loss_fn(pred, y)
   105 5724.762 MiB -7941.312 MiB           1           loss.backward()
   106
   107 5724.762 MiB    0.000 MiB           1           total_loss += loss.item()
   108 5787.035 MiB   62.273 MiB           1           optimizer.step()
   109 5631.707 MiB -155.328 MiB           1           optimizer.zero_grad(set_to_none=True)
   110
   111 5631.707 MiB    0.000 MiB           1           if batch % args.log_interval == 0:
   112                                                     msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
   113                                                         epoch, batch, len(dataloader),
   114                                                         100. * batch / len(dataloader), loss.item())
   115                                                     logging.info(msg)
   116 5631.707 MiB    0.000 MiB           1           break
   117
   118 5631.707 MiB    0.000 MiB           1       return total_loss


Filename: test.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   121  303.352 MiB  303.352 MiB           1   @profile
   122                                         def main():
   123  303.352 MiB    0.000 MiB           1       if args.log_path == "":
   124  303.352 MiB    0.000 MiB           1           logging.basicConfig(
   125  303.352 MiB    0.000 MiB           1               format="%(asctime)s %(levelname)-8s %(message)s",
   126  303.352 MiB    0.000 MiB           1               datefmt="%Y-%m-%dT%H:%M:%SZ",
   127  303.352 MiB    0.000 MiB           1               level=logging.DEBUG)
   128                                             else:
   129                                                 logging.basicConfig(
   130                                                     format="%(asctime)s %(levelname)-8s %(message)s",
   131                                                     datefmt="%Y-%m-%dT%H:%M:%SZ",
   132                                                     level=logging.DEBUG,
   133                                                     filename=args.log_path)
   134
   135  303.352 MiB    0.000 MiB           1       torch.manual_seed(args.seed)
   136
   137                                             # Load the data and dataset
   138  310.824 MiB    7.473 MiB           1       train_data = Dataset('train')
   139  310.824 MiB    0.000 MiB           1       print("[INFO]Train data length:", len(train_data.data), flush=True)
   140
   141  310.824 MiB    0.000 MiB           1       train_dataloader = DataLoader(
   142  310.824 MiB    0.000 MiB           1           train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
   143
   144  310.824 MiB    0.000 MiB           1       print("[INFO]Data get loaded successfully", flush=True)
   145
   146  709.438 MiB  398.613 MiB           1       model = NeuralNetwork().to(device)
   147
   148  709.438 MiB    0.000 MiB           1       loss_fn = nn.CrossEntropyLoss()
   149  709.438 MiB    0.000 MiB           1       optimizer = AdamW(model.parameters(), lr=args.lr)
   150
   151  709.438 MiB    0.000 MiB           1       total_loss = 0.
   152
   153 5631.707 MiB    0.000 MiB           2       for t in range(args.epochs):
   154  709.438 MiB    0.000 MiB           1           print(f"Epoch {t+1}/{args.epochs + 1}\n-------------------------------")
   155  709.438 MiB    0.000 MiB           1           start = time.perf_counter()
   156  709.438 MiB    0.000 MiB           1           total_loss = train_loop(
   157 5631.707 MiB 5631.707 MiB           1               args, train_dataloader, model, loss_fn, optimizer, t+1, total_loss)
   158 5631.707 MiB    0.000 MiB           1           end = time.perf_counter()
   159 5631.707 MiB    0.000 MiB           1           print(f"Epoch {t+1}/{args.epochs + 1} Elapsed time:",
   160 5631.707 MiB    0.000 MiB           1                 end - start, flush=True)

The memory usage increases by around 5 GB after train_loop, while I was expecting all of that memory to be released once train_loop returns.
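
One extra check I could run right after train_loop returns is to count the tensors that are still reachable, using Python's gc module (just a debugging sketch, not part of the script above):

import gc
import torch

def report_live_tensors():
    # Debugging sketch: count the tensors that are still reachable after
    # train_loop returns, to see whether anything besides the model and
    # optimizer state survives the epoch.
    count, total_elems = 0, 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                count += 1
                total_elems += obj.numel()
        except Exception:
            continue
    print(f"[INFO] live tensors: {count}, total elements: {total_elems}", flush=True)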

I don’t know why it would increase the memory usage since you are properly using the item() method, but could you remove the usage of total_loss here:

total_loss = train_loop(
            args, train_dataloader, model, loss_fn, optimizer, t+1, total_loss)

and see if something changes?

The changed source code:

    tracemalloc.start()
    for t in range(args.epochs):
        print(f"Epoch {t+1}/{args.epochs + 1}\n-------------------------------")
        start = time.perf_counter()
        train_loop(args, train_dataloader, model, loss_fn, optimizer, t+1)
        end = time.perf_counter()
        print(f"Epoch {t+1}/{args.epochs + 1} Elapsed time:",
              end - start, flush=True)
        print(f"[INFO] memory usage for epoch {t+1}:", flush=True)
        print(tracemalloc.get_traced_memory())

def train_loop(args, dataloader, model, loss_fn, optimizer, epoch):
    # Set to train mode
    model.train()
    optimizer.zero_grad(set_to_none=True)
    enumerator = enumerate(dataloader, start=1)
    for batch, (X, y) in enumerator:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        if batch % args.log_interval == 0:
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                epoch, batch, len(dataloader),
                100. * batch / len(dataloader), loss.item())
            logging.info(msg)
        if batch > 50:
            break

The memory usage:

[INFO] memory usage for epoch 1:
(874743, 1506418)


[INFO] memory usage for epoch 2:
(895535, 1557957)

[INFO] memory usage for epoch 3:
(909223, 1607835)

[INFO] memory usage for epoch 4:
(923095, 1617011)


[INFO] memory usage for epoch 5:
(936775, 1653743)

The memory usage is still growing.

I also tried the latest release on PyPI; it still has this problem.

I’m having exactly the same issue with “1.12.0+cu102”. Any update about this?

Yes, still having the issue with “1.10.0+cu102”.