Memory issue when running the same code on an A100 and a 4090 with accelerate

I ran the same code on both A100 and 4090 GPUs, but encountered an issue only on the A100.

Problem Description:
On the A100, GPU memory usage on the main device grows gradually with each epoch: at epoch 1, CUDA memory usage is around 75 GB, but by epoch 4 it has grown to 80 GB, eventually leading to an out-of-memory (OOM) error.
The same code running on the 4090 does not show this behavior; GPU memory stays stable throughout training.
A100: CUDA 12.9, accelerate 1.7.0, torch 2.7.0+cu128
4090: CUDA 12.4, accelerate 1.3.0, torch 2.5.1
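
For reference, a minimal sketch of how the per-epoch usage could be logged for comparison (the log_cuda_memory helper below is only illustrative, it is not part of my training script):

import torch

def log_cuda_memory(accelerator, epoch):
    # Print allocated and peak CUDA memory (in GB) on the current device,
    # then reset the peak counter so each epoch is measured independently.
    if not accelerator.is_main_process:
        return
    allocated = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[epoch {epoch}] allocated={allocated:.2f} GB, peak={peak:.2f} GB")
    torch.cuda.reset_peak_memory_stats()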
My code looks like this:

import sys
import os
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), './')))
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CURL_CA_BUNDLE'] = ''
warnings.filterwarnings("ignore")

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report

from config import MyConfig
from model.model import MyModel
from data_loader import MyDataset
from evaluate import compute_acc_f1_recall

config = MyConfig()

def train(epoch, model, optimizer, loss_function, data_loader, accelerator, lr_scheduler):
    model.train()
    progress_bar = tqdm(total=len(data_loader), 
                            disable=not accelerator.is_main_process, 
                            desc=f"Epoch-TRAIN {epoch}")

    for i, batch in enumerate(data_loader):
        with accelerator.accumulate(model):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)

            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)

            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask
            valid_out = out[valid_mask]
            valid_labels = label_data[valid_mask]
            valid_labels = torch.argmax(valid_labels, dim=-1)

            loss = loss_function(valid_out, valid_labels)

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if accelerator.is_main_process:
                acc, f1, recall, balanced_acc = compute_acc_f1_recall(preds= valid_out, labels= valid_labels)
        progress_bar.update(1)
        if accelerator.is_main_process:
            logs = {"Train/loss": loss.item(), 
                    "Train/lr": lr_scheduler.get_last_lr()[0],
                    "Train/ACC": acc, 
                    "Train/F1-Micro": f1, 
                    "Train/Recall": recall, 
                    "Train/ACC-Balance": balanced_acc}
            progress_bar.set_postfix(
                loss=loss.item(), lr=lr_scheduler.get_last_lr()[0],
                acc=acc, f1=f1)
            accelerator.log(logs)


def test(epoch, model, loss_function, data_loader, accelerator):
    model.eval()
    progress_bar = tqdm(total=len(data_loader),
                        disable=not accelerator.is_main_process, 
                        desc=f"Epoch-TEST {epoch}")

    mean_acc, mean_f1 = 0.0, 0.0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)

            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)

            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask 
            valid_out = out[valid_mask] 
            valid_labels = label_data[valid_mask]  
            valid_labels = torch.argmax(valid_labels, dim=-1)

            loss = loss_function(valid_out, valid_labels)
            acc, f1, recall, balanced_acc = compute_acc_f1_recall(preds= valid_out, labels= valid_labels)

            progress_bar.update(1)
            logs = {"Test/loss": loss.item(), 
                    "Test/ACC": acc, 
                    "Test/F1-Micro": f1, 
                    "Test/Recall": recall, 
                    "Test/ACC-Balance": balanced_acc}
            progress_bar.set_postfix(
                loss=loss.item(),
                acc=acc, f1=f1)
            mean_acc += acc
            mean_f1 += f1
            accelerator.log(logs)
    return mean_acc/len(data_loader), mean_f1/ len(data_loader)


@torch.no_grad()
def evaluate(model, pth_path, data_loader, accelerator, data_type: str='test'):
    if config.data_type == 'latex':
        labels = ...  # label list omitted
        index_to_label = {idx: label for idx, label in enumerate(labels)}
    all_preds, all_labels = [], []
    if pth_path is not None:
        checkpoint = torch.load(pth_path, map_location=accelerator.device, weights_only= False)
        try:
            model.module.load_state_dict(checkpoint['model_state_dict'])
        except Exception:
            model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with tqdm(total= len(data_loader), desc=f'Eva-{data_type}') as pbar:
        for i, batch in enumerate(data_loader):
            image_data, label_data, text_data, bbox_data, text_padding, bbox_padding = batch
            label_data = label_data.to(accelerator.device)
            image_data = image_data.to(accelerator.device)
            bbox_data = bbox_data.to(accelerator.device)

            if config.bbox_embedding:
                bbox_padding = bbox_padding.to(accelerator.device)
                out = model(image_data, bbox_data, text_data, bbox_padding)
            else:
                out = model(image_data, bbox_data, text_data)

            bbox_padding_mask = bbox_padding.to(dtype=torch.bool, device=accelerator.device)
            valid_mask = ~bbox_padding_mask 
            valid_out = out[valid_mask]
            valid_labels = label_data[valid_mask]
            valid_labels = torch.argmax(valid_labels, dim=-1)

            preds = torch.argmax(valid_out, dim=-1) 
                
            all_preds.append(preds.detach().cpu().numpy())
            all_labels.append(valid_labels.detach().cpu().numpy())
            pbar.update(1)
    ......

def main(image_model_name= None, features_fushion=None):
    if image_model_name:
        config.image_model_name = image_model_name
        config.features_fushion = features_fushion
        config.output_dir = f"{config.output_dir}/{config.features_fushion}"

    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=False)]
    log_writing = "tensorboard" # if config.small_dataset else ["tensorboard", "wandb"]
    accelerator = Accelerator(mixed_precision= config.mixed_precision, 
                              gradient_accumulation_steps= config.gradient_accumulation_steps,
                              log_with= log_writing,
                              project_dir=os.path.join(config.output_dir, f"logs"),
                              kwargs_handlers= kwargs_handlers
                              )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers(f"Train-{config.pred_heads}")
     
    # data
    train_dataset = MyDataset(...)
    test_dataset = MyDataset(...)
    train_dataloader = DataLoader(train_dataset, batch_size= config.batch_size, 
                                  collate_fn= train_dataset.collate_fn, num_workers= 4)
    test_dataloader = DataLoader(test_dataset, batch_size= config.batch_size,
                                 collate_fn= test_dataset.collate_fn)

    # model
    model = MyModel(...)
    
    if config.lora:
        optimizer = torch.optim.AdamW([
            {'params': model.image_model.parameters(), 'lr': 2e-4, 'weight_decay': 1e-2},
            {'params': model.text_model.parameters(), 'lr': 4e-5, 'weight_decay': 1e-2},
            {'params': [p for n, p in model.named_parameters() 
                if 'image_model' not in n and 'text_model' not in n]},
        ], lr= config.learning_rate)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr= config.learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer,
                                                     start_factor=0.1,
                                                     total_iters= 10 * len(train_dataloader))
    loss_function = torch.nn.CrossEntropyLoss()

    model, optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, 
                                                                                            train_dataloader, 
                                                                                            test_dataloader, 
                                                                                            lr_scheduler)
    
    best_acc, best_f1 = 0.0, 0.0
    for epoch in range(config.epochs):
        train(epoch, model, optimizer, loss_function, train_dataloader, 
              accelerator, lr_scheduler)
        if accelerator.is_main_process:
            mean_acc, mean_f1 = test(epoch, model, loss_function, test_dataloader, accelerator)
            if mean_acc>= best_acc+ 0.01:
                best_acc = mean_acc
                # model_to_save = accelerator.unwrap_model(model)
                model_to_save = accelerator.get_state_dict(model)
                accelerator.save({
                    'model_state_dict': model_to_save,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_acc': best_acc,
                    'config': vars(config)
                }, f"{config.output_dir}/model_acc_best.pth")
                torch.cuda.empty_cache()
            if mean_f1>= best_f1+ 0.01:
                best_f1 = mean_f1
                # model_to_save = accelerator.unwrap_model(model)
                model_to_save = accelerator.get_state_dict(model)
                accelerator.save({
                    'model_state_dict': model_to_save,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_f1': best_f1,
                    'config': vars(config)
                }, f"{config.output_dir}/model_f1_best.pth")
                torch.cuda.empty_cache()

            if epoch % 5 == 0 or epoch == config.epochs - 1:
                for tag, ckpt_path in [("TEST-ACC", f"{config.output_dir}/model_acc_best.pth"),
                                       ("TEST-F1", f"{config.output_dir}/model_f1_best.pth")]:
                    checkpoint = torch.load(ckpt_path, map_location=accelerator.device, weights_only=False)
                    try:
                        model.module.load_state_dict(checkpoint['model_state_dict'])
                    except Exception:
                        model.load_state_dict(checkpoint['model_state_dict'])
                    evaluate(model, None, data_loader=test_dataloader, accelerator=accelerator,
                             data_type=tag)
                    del checkpoint
                    torch.cuda.empty_cache()
        torch.cuda.empty_cache()
    
    accelerator.end_training()

if __name__ == '__main__':
    # CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch --num_processes=3 train.py
    # CUDA_VISIBLE_DEVICES=2,3 accelerate launch --num_processes=2 train.py
    import argparse
    parser = argparse.ArgumentParser(description="Run main function with parameters")
    parser.add_argument('--image_model_name', type=str, default=None, help='Name of the image model')
    parser.add_argument('--features_fusion', type=str, default=None, help='Type of features fusion')
    args = parser.parse_args()
    main(args.image_model_name, args.features_fusion)

It seems you are using different libraries (PyTorch, Accelerate) in the two environments. Do you see the same behavior on your 4090 after updating to the latest stack?
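
For example, printing the exact stack on both machines makes the comparison easy (plain Python, nothing accelerate-specific):

import torch
import accelerate

print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)
print("accelerate:", accelerate.__version__)
print("device:", torch.cuda.get_device_name(0))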

When I run the code on the 4090, the GPU memory usage remains stable at around 29GB throughout training, unlike on the A100 where it keeps increasing over epochs.
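
If the growth persists after aligning the versions, I could also record an allocator snapshot on the A100 to see which allocations accumulate. A sketch using PyTorch's private torch.cuda.memory._record_memory_history API (availability depends on the PyTorch version):

import torch

# Start recording allocation history (private API in recent PyTorch releases).
torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run a few training epochs here ...

# Dump a snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("a100_memory_snapshot.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)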

Wait, let me try it.