Hi guys, I am new to PyTorch, and I encountered a problem during training of a language model using PyTorch with CPU.
I monitor the memory usage of the training program using memory-profiler and cat /proc/xxx/status | grep Vm
. It seems that the RAM isn’t freed after each epoch ends. Hence, memory usage doesn’t stay constant after the first epoch, as it should. Eventually, after some number of epochs, this leads to an OOM error on the CPU.
The source code is included as follows (the model and datasets should be downloaded automatically if needed):
import torch
import argparse
import os
import logging
import time
from torch import nn
from contextlib import nullcontext
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from transformers import BertTokenizer, BertModel, AdamW
# Command-line configuration for the training run.
# FIX: the help strings previously disagreed with the actual defaults
# (--epochs claimed "default: 10" while defaulting to 1; --lr claimed
# "default: 0.01" while defaulting to 1e-5). Help now matches reality.
parser = argparse.ArgumentParser(description="PyTorch PERT Example")
parser.add_argument("--batch-size", type=int, default=16, metavar="N",
                    help="input batch size for training (default: 16)")
parser.add_argument("--epochs", type=int, default=1, metavar="N",
                    help="number of epochs to train (default: 1)")
parser.add_argument("--lr", type=float, default=1e-5, metavar="LR",
                    help="learning rate (default: 1e-5)")
parser.add_argument("--seed", type=int, default=1, metavar="S",
                    help="random seed (default: 1)")
# Only for test purpose
parser.add_argument("--log-interval", type=int, default=2, metavar="N",
                    help="how many batches to wait before logging training status")
parser.add_argument("--log-path", type=str, default="",
                    help="Path to save logs. Print to StdOut if log-path is not set")
args = parser.parse_args()
# CPU-only training (this post is specifically about CPU memory behavior).
device = 'cpu'
# Define dataset, so it is easier to load different split in the dataset
# Map-style dataset: eagerly materializes one split of ChnSentiCorp into an
# index -> sample dict so samples can be served by integer position.
class Dataset(torch.utils.data.Dataset):

    def __init__(self, data_type):
        # data_type names the split to load (e.g. 'train').
        self.data = self.load_data(data_type)

    def load_data(self, data_type):
        """Download/load the requested split and index its records by position."""
        raw_split = load_dataset(path='seamew/ChnSentiCorp', split=data_type)
        return {position: record for position, record in enumerate(raw_split)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# Hugging Face hub ID of the pretrained Chinese PERT encoder used below.
checkpoint = 'hfl/chinese-pert-base'
# Tokenizer matching the checkpoint; sequences longer than 512 tokens will be
# truncated to the model's maximum length when truncation is requested.
tokenizer = BertTokenizer.from_pretrained(checkpoint, model_max_length=512)
# Return a batch of data, which is used for training
def collate_fn(batch_samples):
    """Collate raw samples into (tokenized input batch, label tensor)."""
    texts = [sample['text'] for sample in batch_samples]
    labels = [int(sample['label']) for sample in batch_samples]
    # The tokenizer pads/truncates so every sequence in the batch aligns,
    # and returns PyTorch tensors ready for the model.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    return encoded, torch.tensor(labels)
class NeuralNetwork(nn.Module):
    """Pretrained PERT encoder followed by a 2-way linear classification head."""

    def __init__(self):
        super().__init__()
        # Encoder emits 768-dim hidden states (hence the 768 -> 2 head).
        self.bert_encoder = BertModel.from_pretrained(checkpoint)
        self.classifier = nn.Linear(768, 2)

    def forward(self, x):
        # Classify from the hidden state at position 0 (the [CLS] token).
        encoder_out = self.bert_encoder(**x)
        cls_state = encoder_out.last_hidden_state[:, 0]
        return self.classifier(cls_state)
def train_loop(args, dataloader, model, loss_fn, optimizer, epoch, total_loss):
    """Run one training epoch and return the running loss accumulator.

    total_loss is carried across epochs by the caller; the scalar loss is
    accumulated via .item() so no autograd graph is retained between steps.
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)
    n_batches = len(dataloader)
    for step, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to(device), y.to(device)
        loss = loss_fn(model(X), y)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        if step % args.log_interval == 0:
            msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                epoch, step, n_batches,
                100. * step / n_batches, loss.item())
            logging.info(msg)
    return total_loss
def main():
    """Entry point: configure logging, build data/model/optimizer, train."""
    if args.log_path == "":
        # No log file given: basicConfig defaults to the console.
        logging.basicConfig(
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%SZ",
            level=logging.DEBUG)
    else:
        logging.basicConfig(
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%Y-%m-%dT%H:%M:%SZ",
            level=logging.DEBUG,
            filename=args.log_path)
    torch.manual_seed(args.seed)
    # Load the data and dataset
    train_data = Dataset('train')
    print("[INFO]Train data length:", len(train_data.data), flush=True)
    train_dataloader = DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    print("[INFO]Data get loaded successfully", flush=True)
    model = NeuralNetwork().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr)
    total_loss = 0.
    for t in range(args.epochs):
        # BUG FIX: the progress display used args.epochs + 1 as the total,
        # printing e.g. "Epoch 1/2" for a single-epoch run. Use args.epochs.
        print(f"Epoch {t+1}/{args.epochs}\n-------------------------------")
        start = time.perf_counter()
        total_loss = train_loop(
            args, train_dataloader, model, loss_fn, optimizer, t+1, total_loss)
        end = time.perf_counter()
        print(f"Epoch {t+1}/{args.epochs} Elapsed time:",
              end - start, flush=True)


if __name__ == "__main__":
    main()
I have searched some related posts, but most of the issues described there are caused by appending tensors to lists. However, I cannot find a matching cause in my code.
Thanks in advance!