I ran the same code on both A100 and 4090 GPUs, but encountered an issue only on the A100.
Problem Description:
On the A100, GPU memory usage on the main device increases gradually with each epoch: at epoch 1, CUDA memory usage is around 75 GB, but by epoch 4 it has grown to 80 GB, eventually leading to an out-of-memory (OOM) error. The same code running on the 4090 does not show this behavior; GPU memory stays stable throughout training.
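One way to quantify the growth is with the standard torch.cuda allocator counters; a minimal sketch (log_gpu_memory is just a helper name for this post, called once per epoch):

import torch

def log_gpu_memory(tag: str, device: int = 0) -> None:
    # allocated = live tensors; reserved = blocks cached by the allocator.
    # Reserved growing while allocated stays flat points at caching or
    # fragmentation rather than leaked tensors.
    alloc = torch.cuda.memory_allocated(device) / 2**30
    reserved = torch.cuda.memory_reserved(device) / 2**30
    peak = torch.cuda.max_memory_allocated(device) / 2**30
    print(f"[{tag}] allocated={alloc:.1f} GiB, reserved={reserved:.1f} GiB, peak={peak:.1f} GiB")

# usage, e.g. at the end of each epoch:
# log_gpu_memory(f"epoch {epoch}")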
Environments:
A100: CUDA 12.9, accelerate 1.7.0, torch 2.7.0+cu128
4090: CUDA 12.4, accelerate 1.3.0, torch 2.5.1
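To rule out ambiguity between the driver's CUDA version and the toolkit each torch wheel is built with, the versions can be checked in the launched environment, e.g.:

import torch
import accelerate

print(torch.__version__, torch.version.cuda)  # build-time toolkit, e.g. 2.7.0+cu128 / 12.8
print(accelerate.__version__)
print(torch.cuda.get_device_name(0))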
My code is as follows:
import sys
import os
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './')))
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CURL_CA_BUNDLE'] = ''
warnings.filterwarnings("ignore")
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report
from config import MyConfig
from model.model import MyModel
from data_loader import MyDataset
from evaluate import compute_acc_f1_recall
config = MyConfig()
def train(epoch, model, optimizer, loss_function, data_loader, accelerator, lr_scheduler):
    model.train()
    progress_bar = tqdm(total=len(data_loader),
                        disable=not accelerator.is_main_process,
                        desc=f"Epoch-TRAIN {epoch}")
    for i, batch in enumerate(data_loader):
        with accelerator.accumulate(model):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)
            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)
            # Compute the loss only over non-padded bbox positions.
            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask
            valid_out = out[valid_mask]
            valid_labels = label_data[valid_mask]
            valid_labels = torch.argmax(valid_labels, dim=-1)
            loss = loss_function(valid_out, valid_labels)
            accelerator.backward(loss)
            # sync_gradients is True only on the step where gradients are actually synced.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        if accelerator.is_main_process:
            acc, f1, recall, balanced_acc = compute_acc_f1_recall(preds=valid_out, labels=valid_labels)
        progress_bar.update(1)
        if accelerator.is_main_process:
            logs = {"Train/loss": loss.item(),
                    "Train/lr": lr_scheduler.get_last_lr()[0],
                    "Train/ACC": acc,
                    "Train/F1-Micro": f1,
                    "Train/Recall": recall,
                    "Train/ACC-Balance": balanced_acc}
            progress_bar.set_postfix(
                loss=loss.item(), lr=lr_scheduler.get_last_lr()[0],
                acc=acc, f1=f1)
            accelerator.log(logs)
def test(epoch, model, loss_function, data_loader, accelerator):
    model.eval()
    progress_bar = tqdm(total=len(data_loader),
                        disable=not accelerator.is_main_process,
                        desc=f"Epoch-TEST {epoch}")
    mean_acc, mean_f1 = 0.0, 0.0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)
            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)
            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask
            valid_out = out[valid_mask]
            valid_labels = label_data[valid_mask]
            valid_labels = torch.argmax(valid_labels, dim=-1)
            loss = loss_function(valid_out, valid_labels)
            acc, f1, recall, balanced_acc = compute_acc_f1_recall(preds=valid_out, labels=valid_labels)
            progress_bar.update(1)
            logs = {"Test/loss": loss.item(),
                    "Test/ACC": acc,
                    "Test/F1-Micro": f1,
                    "Test/Recall": recall,
                    "Test/ACC-Balance": balanced_acc}
            progress_bar.set_postfix(
                loss=loss.item(),
                acc=acc, f1=f1)
            mean_acc += acc
            mean_f1 += f1
            accelerator.log(logs)
    return mean_acc / len(data_loader), mean_f1 / len(data_loader)
@torch.no_grad()
def evaluate(model, pth_path, data_loader, accelerator, data_type: str = 'test'):
    if config.data_type == 'latex':
        labels = ...  # (label list omitted)
    index_to_label = {idx: label for idx, label in enumerate(labels)}
    all_preds, all_labels = [], []
    if pth_path is not None:
        checkpoint = torch.load(pth_path, map_location=accelerator.device, weights_only=False)
        try:
            model.module.load_state_dict(checkpoint['model_state_dict'])
        except Exception:
            model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    with tqdm(total=len(data_loader), desc=f'Eva-{data_type}') as pbar:
        for i, batch in enumerate(data_loader):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)
            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)
            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask
            valid_out = out[valid_mask]
            valid_labels = label_data[valid_mask]
            valid_labels = torch.argmax(valid_labels, dim=-1)
            preds = torch.argmax(valid_out, dim=-1)
            all_preds.append(preds.detach().cpu().numpy())
            all_labels.append(valid_labels.detach().cpu().numpy())
            pbar.update(1)
    ...  # (metric reporting omitted)
def main(image_model_name=None, features_fushion=None):
    if image_model_name:
        config.image_model_name = image_model_name
        config.features_fushion = features_fushion
        config.output_dir = f"{config.output_dir}/{config.features_fushion}"
    kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=False)]
    log_writing = "tensorboard"  # if config.small_dataset else ["tensorboard", "wandb"]
    accelerator = Accelerator(mixed_precision=config.mixed_precision,
                              gradient_accumulation_steps=config.gradient_accumulation_steps,
                              log_with=log_writing,
                              project_dir=os.path.join(config.output_dir, "logs"),
                              kwargs_handlers=kwargs_handlers)
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers(f"Train-{config.pred_heads}")
    # data
    train_dataset = MyDataset(...)
    test_dataset = MyDataset(...)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size,
                                  collate_fn=train_dataset.collate_fn, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size,
                                 collate_fn=test_dataset.collate_fn)
    # model
    model = MyModel(...)
    if config.lora:
        optimizer = torch.optim.AdamW([
            {'params': model.image_model.parameters(), 'lr': 2e-4, 'weight_decay': 1e-2},
            {'params': model.text_model.parameters(), 'lr': 4e-5, 'weight_decay': 1e-2},
            {'params': [p for n, p in model.named_parameters()
                        if 'image_model' not in n and 'text_model' not in n]},
        ], lr=config.learning_rate)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer,
                                                     start_factor=0.1,
                                                     total_iters=10 * len(train_dataloader))
    loss_function = torch.nn.CrossEntropyLoss()
    model, optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, lr_scheduler)
    best_acc, best_f1 = 0.0, 0.0
    for epoch in range(config.epochs):
        train(epoch, model, optimizer, loss_function, train_dataloader,
              accelerator, lr_scheduler)
        if accelerator.is_main_process:
            mean_acc, mean_f1 = test(epoch, model, loss_function, test_dataloader, accelerator)
            if mean_acc >= best_acc + 0.01:
                best_acc = mean_acc
                # model_to_save = accelerator.unwrap_model(model)
                model_to_save = accelerator.get_state_dict(model)
                accelerator.save({
                    'model_state_dict': model_to_save,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_acc': best_acc,
                    'config': vars(config)
                }, f"{config.output_dir}/model_acc_best.pth")
                torch.cuda.empty_cache()
            if mean_f1 >= best_f1 + 0.01:
                best_f1 = mean_f1
                # model_to_save = accelerator.unwrap_model(model)
                model_to_save = accelerator.get_state_dict(model)
                accelerator.save({
                    'model_state_dict': model_to_save,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_f1': best_f1,
                    'config': vars(config)
                }, f"{config.output_dir}/model_f1_best.pth")
                torch.cuda.empty_cache()
            if epoch % 5 == 0 or epoch == config.epochs - 1:
                for tag, ckpt_path in [("TEST-ACC", f"{config.output_dir}/model_acc_best.pth"),
                                       ("TEST-F1", f"{config.output_dir}/model_f1_best.pth")]:
                    checkpoint = torch.load(ckpt_path, map_location=accelerator.device, weights_only=False)
                    try:
                        model.module.load_state_dict(checkpoint['model_state_dict'])
                    except Exception:
                        model.load_state_dict(checkpoint['model_state_dict'])
                    evaluate(model, None, data_loader=test_dataloader, accelerator=accelerator,
                             data_type=tag)
                    del checkpoint
                    torch.cuda.empty_cache()
        torch.cuda.empty_cache()
    accelerator.end_training()
if __name__ == '__main__':
    # CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch --num_processes=3 train.py
    # CUDA_VISIBLE_DEVICES=2,3 accelerate launch --num_processes=2 train.py
    import argparse
    parser = argparse.ArgumentParser(description="Run main function with parameters")
    parser.add_argument('--image_model_name', type=str, default=None, help='Name of the image model')
    parser.add_argument('--features_fusion', type=str, default=None, help='Type of features fusion')
    args = parser.parse_args()
    main(args.image_model_name, args.features_fusion)