DDP Training Freezes with Accelerate Library on Multi-GPU Setup

I'm running into an issue with distributed training using Hugging Face's Accelerate library. The training process freezes right after the dataloaders are set up when using multiple GPUs, but the same code runs fine on a single GPU.

Environment:

  • Python 3.10
  • PyTorch 2.1
  • Accelerate library from Hugging Face
  • Model: DeBERTa-v3-base
  • Kernel version: 5.4.0
  • Using 2 GPUs

Command Used:
accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 main.py

Current Behavior:

  • Training process initializes successfully (model loading, tokenization, and data mapping complete)
  • Process freezes after the dataloader stage
  • No error message is displayed; the run simply stops making progress

Additional Information:

  1. The code successfully runs on a single GPU
  2. Mixed precision (fp16) is enabled
  3. Data preprocessing appears successful (mapping shows 100% completion)
  4. Using NCCL backend for distributed training

Log Output:

root@af4cc4b13b7c:/workspace/embedding_layer# accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 main.py
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_machines` was set to a value of `1`
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
01/06/2025 11:56:05 - INFO - __main__ - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 2
Process index: 1
Local process index: 1
Device: cuda:1

Mixed precision type: fp16

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: jamesjohnson1097 (threado_ml). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.19.1
wandb: Run data is saved locally in /workspace/embedding_layer/wandb/run-20250106_115605-l9j5glhs
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fresh-sound-5
wandb: ā­ļø View project at https://wandb.ai/threado_ml/embed-train
wandb: šŸš€ View run at https://wandb.ai/threado_ml/embed-train/runs/l9j5glhs
01/06/2025 11:56:06 - INFO - __main__ - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 2
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
setting the seed 42
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/config.json
Model config DebertaV2Config {
  "_name_or_path": "microsoft/deberta-v3-base",
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-07,
  "legacy": true,
  "max_position_embeddings": 512,
  "max_relative_positions": -1,
  "model_type": "deberta-v2",
  "norm_rel_ebd": "layer_norm",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_dropout": 0,
  "pooler_hidden_act": "gelu",
  "pooler_hidden_size": 768,
  "pos_att_type": [
    "p2c",
    "c2p"
  ],
  "position_biased_input": false,
  "position_buckets": 256,
  "relative_attention": true,
  "share_att_key": true,
  "transformers_version": "4.47.1",
  "type_vocab_size": 0,
  "vocab_size": 128100
}

loading file spm.model from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/spm.model
loading file tokenizer.json from cache at None
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/tokenizer_config.json
loading file chat_template.jinja from cache at None
/usr/local/lib/python3.10/dist-packages/transformers/convert_slow_tokenizer.py:561: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.
  warnings.warn(
Map: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 165767/165767 [01:24<00:00, 1969.81 examples/s]
Map: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 165767/165767 [00:38<00:00, 4269.71 examples/s]
Map: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1679/1679 [00:00<00:00, 2088.85 examples/s]
Map: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1679/1679 [00:00<00:00, 3674.69 examples/s]
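
For reference, a bare-bones sanity check along the following lines (a sketch, not part of my training code; the file name sanity_check.py is just illustrative) should help separate an NCCL/kernel-level hang, which the 5.4.0 warning above hints at, from a problem in the dataset/collator code further down. It prepares a tiny model and a dummy dataloader and runs a single pass with the same accelerate launch command.

# sanity_check.py (illustrative name) -- run with the same command:
#   accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 sanity_check.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
accelerator.print(accelerator.state)

# tiny model + random data: no tokenizer, no custom collator
model = nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ds = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
dl = DataLoader(ds, batch_size=8, shuffle=True)

model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

for x, y in dl:
    loss = nn.functional.cross_entropy(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

accelerator.wait_for_everyone()
accelerator.print("finished one pass on", accelerator.num_processes, "processes")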

Below is the code:

#main.py

# A script to train the model to identify similar sentences
from ai_dataset import AiDataset
from ai_collator import AiCollatorTrain, AiCollator
from transformers import get_cosine_schedule_with_warmup
from ai_optimizer import get_optimizer
from ai_model import AiModel
from omegaconf import  OmegaConf
from torch.utils.data import DataLoader
import pandas as pd
import os
import torch
from tqdm import tqdm
from torchmetrics import MeanMetric
import numpy as np
import shutil
from accelerate import Accelerator
from accelerate.utils import set_seed
import datasets
import transformers
from accelerate.logging import get_logger
import logging
import time
import wandb


class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

logger = get_logger(__name__)

def get_lr(optimizer):
    return optimizer.param_groups[0]['lr']*1e6  # note: returns the LR scaled by 1e6 (used in the progress bar and wandb logging)

def run_evaluation(model, valid_dl):
    model.eval()
    all_losses = []
    progress_bar = tqdm(total=len(valid_dl), disable=not accelerator.is_local_main_process)
    for batch in valid_dl:
        with torch.no_grad():
            loss = model(**batch)
        
        batch_losses = accelerator.gather_for_metrics(loss) # tensor of losses from all the gpu
        batch_losses = batch_losses.cpu().numpy().tolist()
        all_losses.extend(batch_losses)
        progress_bar.update(1)
    progress_bar.close()
    
    eval_dict = dict()
    eval_dict['valid_loss'] = np.mean(all_losses)
    return eval_dict

def save_checkpoint(cfg, state, is_best = False):
    os.makedirs(cfg.train_params.output_model_dir, exist_ok=True)
    project_name = 'detect_ai'

    file_name = f"{cfg.train_params.output_model_dir}/{project_name}_{state['epoch']}_{state['step']}.tar"
    torch.save(state, file_name, _use_new_zipfile_serialization = False)

    if is_best:
        shutil.copyfile(file_name,f"{cfg.train_params.output_model_dir}/{project_name}_{state['epoch']}_{state['step']}_best.tar")


def get_latest_checkpoint(output_dir, best_only=True):
    """
    Find the latest checkpoint in the output directory
    Args:
        output_dir: directory containing checkpoints
        best_only: if True, only consider *_best.tar files
    Returns:
        path to latest checkpoint
    """
    files = os.listdir(output_dir)
    if best_only:
        checkpoints = [f for f in files if f.endswith('_best.tar')]
    else:
        checkpoints = [f for f in files if f.endswith('.tar')]
    
    if not checkpoints:
        return None
    # Extract epoch and step numbers
    checkpoint_info = []
    for ckpt in checkpoints:
        parts = ckpt.replace('_best.tar', '').split('_') if 'best' in ckpt else ckpt.replace('.tar', '').split('_')
        epoch, step = int(parts[-2]), int(parts[-1])
        checkpoint_info.append((epoch, step, ckpt))
    
    # Sort by epoch and step
    checkpoint_info.sort(key=lambda x: (x[0], x[1]))
    return os.path.join(output_dir, checkpoint_info[-1][2])

def print_line():
    prefix, unit, suffix = "#", "~~", "#"
    accelerator.print(prefix + unit*50 + suffix)


if __name__ == '__main__':

    # load the config
    cfg = OmegaConf.load('./conf/cfg.yaml')

    # define the accelerate 
    accelerator = Accelerator(
        log_with='wandb',
        gradient_accumulation_steps=cfg.train_params.gradient_accumulation_steps
    )

    accelerator.init_trackers(
        cfg.wandb.project,
        config= OmegaConf.to_container(cfg=cfg,resolve=True)
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)


    # In the main process, set the logging verbosity to warning
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else: # in the other processes, set the logging verbosity to error
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    print_line()
    # set the seed 
    accelerator.print(f'setting the seed {cfg.seed}')
    set_seed(cfg.seed)

    # if it is the main process, then create a directory 
    if accelerator.is_local_main_process:
        os.makedirs(cfg.train_params.output_model_dir,exist_ok=True)
    print_line()

    # read the parquet file 
    train_df = pd.read_parquet(os.path.join(cfg.input_data_dir,'train_essays.parquet'))
    valid_df = pd.read_parquet(os.path.join(cfg.input_data_dir,'valid_essays.parquet'))

    # reset the index 
    train_df = train_df.reset_index(drop=True)
    valid_df = valid_df.reset_index(drop=True)

    prompt_ids = train_df["prompt_id"].unique().tolist()
    prompt_ids = [p for p in prompt_ids if p <= 8]  # keep only prompt ids <= 8

    pos_df = train_df[train_df['generated'] == 1].copy()
    neg_df = train_df[train_df['generated'] == 0].copy()

    # map each prompt id to the list of its positive / negative example ids
    pos_gdf = pos_df.groupby('prompt_id')['id'].apply(list).reset_index()
    prompt2id_pos = dict(zip(pos_gdf['prompt_id'],pos_gdf['id']))

    neg_gdf = neg_df.groupby('prompt_id')['id'].apply(list).reset_index()
    prompt2id_neg = dict(zip(neg_gdf['prompt_id'], neg_gdf['id']))


    with accelerator.main_process_first():
        dataset = AiDataset(cfg=cfg)
        train_ds = dataset.get_dataset(df=train_df)
        valid_ds = dataset.get_dataset(df=valid_df)

    tokenizer = dataset.tokenizer

    # set the data loaders 
    train_ds.set_format(
        type=None,
        columns=[
            'id',
            'input_ids',
            'attention_mask',
            'generated'
        ]
    )


    valid_ds.set_format(
        type=None,
        columns=[
            'id',
            'input_ids',
            'attention_mask',
            'generated'
        ]
    )

    # extract the ids, useful for submission
    valid_ids = valid_ds['id']

    kwargs = dict(
        train_ds=train_ds,
        prompt_ids=prompt_ids,
        prompt2id_pos=prompt2id_pos,
        prompt2id_neg=prompt2id_neg
    )

    # Build the collator for train and valid 
    ai_collator_train = AiCollatorTrain(
        tokenizer=tokenizer,
        pad_to_multiple_of=64,
        kwargs=kwargs
        )

    ai_collator = AiCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=64
    )

    train_dl = DataLoader(
        train_ds,
        shuffle=True,
        batch_size=cfg.train_params.per_device_train_batch_size,
        pin_memory=True,
        collate_fn=ai_collator_train
    )

    valid_dl = DataLoader(
        valid_ds,
        shuffle=False,
        batch_size=cfg.train_params.per_device_eval_batch_size,
        pin_memory=True,
        collate_fn=ai_collator
    )
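    # note: both loaders use the DataLoader default num_workers=0 (no worker subprocesses);
    # accelerator.prepare() below wraps them so that each rank sees its own shard of the data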
    accelerator.print(f'Data preparation is done...')
    print_line()

    # ------ Model -------
    model = AiModel(cfg=cfg,device=accelerator.device) # line changed
    print_line()

    # --- optimizer ------
    optimizer = get_optimizer(cfg=cfg,model=model)
    print_line()

    # --- prepare the model ------
    model, optimizer, train_dl, valid_dl = accelerator.prepare(model, optimizer, train_dl, valid_dl)

    # ----- scheduler 
    num_train_epochs = cfg.train_params.num_train_epochs
    gradient_accumulation_steps = cfg.train_params.gradient_accumulation_steps
    warmup_pct = cfg.train_params.warmup_pct
    
    num_update_steps_per_epoch = len(train_dl) // gradient_accumulation_steps
    num_training_steps = num_train_epochs * num_update_steps_per_epoch
    num_warmup_steps = int(warmup_pct * num_training_steps)
    

    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # -- load the previous checkpoint if any -- 
    checkpoint_path = get_latest_checkpoint(cfg.train_params.output_model_dir)
    if checkpoint_path:
        accelerator.print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        
        accelerator.wait_for_everyone()

        # Load the checkpoint into the unwrapped model; keep the prepared (DDP-wrapped)
        # `model` reference intact so training still synchronizes gradients
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.load_state_dict(checkpoint['state_dict'])

        accelerator.print(f"Loaded checkpoint: {checkpoint_path}")
        
        start_epoch = checkpoint['epoch']-1
        current_iteration = checkpoint['step']
        best_lb = checkpoint['lb']
        
        # Adjust scheduler
        for _ in range(current_iteration):
            scheduler.step()

        accelerator.print(f"Skipped Scheduler to current_iteration {current_iteration}")
    else:
        start_epoch = 0
        current_iteration = 0
        best_lb = 1e6


    # --- training setup ----

    patience_tracker = 0

    # ----- training -----
    start_time = time.time()
    accelerator.wait_for_everyone()
    
    progress_bar = None
    for epoch in range(start_epoch, num_train_epochs):
        # reset the progress for every epoch 
        if epoch != 0 and progress_bar is not None:
            progress_bar.close()
        
        loss_meter = AverageMeter()
        progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
        model.train()
        for step, batch in enumerate(train_dl):
            with accelerator.accumulate(model): # performs gradient accumulation
                loss = model(**batch)
                accelerator.backward(loss)

                # look for the gradient accumulation trigger
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
            
                loss_meter.update(loss.item())

            if accelerator.sync_gradients:

                progress_bar.set_description(
                    f'STEP: {current_iteration+1:5}/{num_update_steps_per_epoch:5}.'
                    f'LR: {get_lr(optimizer):.4f}.' 
                    f'LOSS: {loss_meter.avg:.4f}.' 
                )
                progress_bar.update(1)
                current_iteration += 1
            
                if cfg.use_wandb:
                    accelerator.log({'train_loss': loss_meter.avg},step=current_iteration)
                    accelerator.log({'lr':get_lr(optimizer=optimizer)}, step=current_iteration) 
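            # note: the evaluation block below calls gather_for_metrics (inside run_evaluation)
            # and wait_for_everyone, which are collective operations, so every rank has to reach
            # it together; current_iteration is advanced identically on all processes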
            
            if accelerator.sync_gradients and (current_iteration % cfg.train_params.eval_frequency == 0):
                
                model.eval()
                scores_dict = run_evaluation(model,valid_dl)
                lb = scores_dict['valid_loss']
                
                print_line()
                et = time.time() - start_time
                accelerator.print(
                    f">>> Epoch {epoch+1} | Step {step} | Total Step {current_iteration} | Time: {et}"
                )
                print_line()
                accelerator.print(f">>> Current LB (valid_loss) = {round(lb, 4)}")

                is_best = False
                if lb <= best_lb:
                    best_lb = lb 
                    is_best = True 
                    patience_tracker = 0
                else:
                    patience_tracker += 1
                
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                model_state = {
                    'step': current_iteration,
                    'epoch': epoch + 1,
                    'state_dict': unwrapped_model.state_dict(),
                    'lb': lb
                }
                if accelerator.is_main_process: # save the checkpoint only in main process
                    save_checkpoint(cfg,state=model_state, is_best=is_best)

                model.train()
                torch.cuda.empty_cache()
                print_line()

                if patience_tracker >= cfg.train_params.patience:
                    accelerator.print('No improvement in validation loss.')
                    model.eval()
                    accelerator.end_training()
                    break

#ai_optimizer.py

import torch.optim as optim 

'''
identify the parameters of model and group them based on decay and no_decay
'''

def get_optimizer_grouped_parameters_with_llrd(cfg, model):
    no_decay = ['bias','LayerNorm.bias', 'LayerNorm.weight']

    # set the hyperparameters for the head layer
    optimizer_grouped_parameters = [{
        'params':[p for n,p in model.named_parameters() if 'backbone' not in n],
        'lr': cfg.optimizer.head_lr,
        'weight_decay': cfg.optimizer.weight_decay
    }]

    layers = [model.backbone.embeddings] + list(model.backbone.encoder.layer) # made wrong
    layers.reverse()
    lr = cfg.optimizer.lr

    # walk from the top encoder layer down to the embeddings,
    # shrinking the learning rate by a factor of llrd at each layer
    for layer in layers:
        lr *= cfg.optimizer.llrd

        optimizer_grouped_parameters += [
            { # decayable parameters()
                'params': [p for n,p in layer.named_parameters() if not any(nd in n for nd in no_decay)],
                'lr':lr,
                'weight_decay': cfg.optimizer.weight_decay
            },
            { # non decayable parameters()
                'params': [p for n,p in layer.named_parameters() if any(nd in n for nd in no_decay)],
                'lr':lr,
                'weight_decay':0.0
            }
        ]
    return optimizer_grouped_parameters
        

def get_optimizer_grouped_parameters_with_no_llrd(cfg, model):
    no_decay = ['bias','LayerNorm.weight','LayerNorm.bias']
    backbone_params = list(model.backbone.named_parameters())  # materialize: a generator would be exhausted by the first comprehension below

    optimizer_grouped_parameters = [
        {
            'params': [p for n,p in model.named_parameters() if 'backbone' not in n],
            'lr':cfg.optimizer.lr,
            'weight_decay': cfg.optimizer.weight_decay
        },
        {
            'params': [p for n,p in backbone_params if not any(nd in n for nd in no_decay)],
            'lr':cfg.optimizer.lr,
            'weight_decay': cfg.optimizer.weight_decay
        },
        {
            'params':[p for n, p in backbone_params if any(nd in n for nd in no_decay)],
            'lr':cfg.optimizer.lr,
            'weight_decay':0.0
        }
    ]
    return optimizer_grouped_parameters


def get_optimizer(cfg, model):

    # configure the optimizer based on learning rate decay 
    optimizer_grouped_parameters = None
    if cfg.optimizer.use_llrd:
        optimizer_grouped_parameters = get_optimizer_grouped_parameters_with_llrd(cfg, model)
    else:
        optimizer_grouped_parameters = get_optimizer_grouped_parameters_with_no_llrd(cfg, model)

    # define the optimizer 
    optimizer_kwargs = {
        'betas': (cfg.optimizer.beta1, cfg.optimizer.beta2),
        'eps': cfg.optimizer.eps,
        'lr': cfg.optimizer.lr
    }
    
    if cfg.optimizer.use_bnb:
        import bitsandbytes as bnb

        return bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            **optimizer_kwargs
        )
    
    else:

        return optim.AdamW(
            optimizer_grouped_parameters,
            **optimizer_kwargs
        )
#ai_model.py

from transformers import AutoConfig, AutoModel
import torch.nn as nn
import torch.nn.functional as F
import torch

class MeanPooling(nn.Module):

    def __init__(self):
        super(MeanPooling, self).__init__()
    
    def forward(self, last_hidden_state, attention_mask):
        '''
        last_hidden_state :[batch_size, seq_len, hidden_dim]
        attention_mask: [batch_size, seq_len]
        '''
        input_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask, dim=1)
        sum_attention = input_mask.sum(dim = 1)
        # if it is 0, then we clamp attention mask 
        sum_attention = torch.clamp(sum_attention, min=1e-9)
        mean_embeddings = sum_embeddings / sum_attention
        return mean_embeddings

# define the loss function 
class SupContrastiveLoss(nn.Module):

    def __init__(self, temperature, device):
        super(SupContrastiveLoss, self).__init__()
        self.device = device 
        self.temperature = temperature
    
    def forward(self,normalized_embeddings, labels):
        N = normalized_embeddings.size()[0]
        labels = labels.reshape((N,1))
        similarity_mask = torch.ones((N,N)).fill_diagonal_(0).to(device=self.device)

        pos_mask = torch.eq(labels, labels.T).float()
        neg_mask = torch.abs(pos_mask - 1)

        H = torch.matmul(normalized_embeddings,normalized_embeddings.T) * similarity_mask
        H_pos = H * pos_mask
        H_neg = H * neg_mask

        # row-wise mean of exp(similarity / temperature) over same-label and different-label pairs
        v_pos = torch.mean(torch.exp(torch.div(H_pos, self.temperature)), dim=1)
        v_neg = torch.mean(torch.exp(torch.div(H_neg, self.temperature)), dim=1)

        # supervised contrastive objective: -1/N * sum_i log(v_pos_i / (v_pos_i + v_neg_i))
        loss = (-1/N) * torch.sum(torch.log(v_pos / (v_pos + v_neg)))

        return loss

    

class AiModel(nn.Module):

    def __init__(self, cfg, device=None):
        super(AiModel,self).__init__()

        backbone_config = AutoConfig.from_pretrained(cfg.model.backbone_path)
        backbone_config.update(
            {
                'use_cache':False
            }
        )

        self.backbone = AutoModel.from_pretrained(
            cfg.model.backbone_path,
            config = backbone_config
        )

        if cfg.model.gradient_checkpoint:
            self.backbone.gradient_checkpointing_enable()
        

        # define dropout, pool, projection_head 
        self.dropout = nn.Dropout(cfg.model.dropout_rate)
        self.pool = MeanPooling()
        self.loss_fn = SupContrastiveLoss(temperature=cfg.model.temperature, device=device)
        hidden_size, projection_dim = self.backbone.config.hidden_size, cfg.model.projection_dim
        self.projection_head = nn.Sequential(
            nn.Dropout(cfg.model.dropout_rate),
            nn.Linear(hidden_size, projection_dim),
            nn.ReLU(),
            nn.Linear(projection_dim, projection_dim)
        )

    def forward(self, input_ids, attention_mask, labels=None):
        
        outputs = self.backbone(input_ids, attention_mask=attention_mask, output_hidden_states=False)
        last_hidden_state = outputs.last_hidden_state

        embeddings = self.pool(last_hidden_state, attention_mask)
        projection_space = self.projection_head(embeddings)
        normalized_proj =  F.normalize(projection_space, dim=-1)
        
        loss = None
        if labels is not None:
            loss = self.loss_fn(normalized_proj, labels)

        return loss

#ai_dataset.py

from transformers import AutoTokenizer
from copy import deepcopy
from datasets import Dataset

class AiDataset:

    def __init__(self, cfg):
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.backbone_path)

    def tokenize(self, examples):
        tz = self.tokenizer(
            examples['text'],
            padding=False,
            truncation=True,
            max_length=self.cfg.model.max_length,
            add_special_tokens=True,
            return_token_type_ids=False
        )
        return tz

    def compute_length(self,examples):
        return {'input_length': [len(x) for x in examples['input_ids']]}

    def get_dataset(self, df):
        dataset = Dataset.from_pandas(df)
        dataset = dataset.map(self.tokenize,batch_size=32, batched=True)
        dataset = dataset.map(self.compute_length,batch_size=32, batched=True)
        return dataset

#ai_collator.py

from transformers import DataCollatorWithPadding
from dataclasses import dataclass, field 
import time 
import os
import random
import torch

@dataclass
class AiCollatorTrain(DataCollatorWithPadding):

    tokenizer = None
    padding = None 
    max_length = None
    pad_to_multiple_of = None
    kwargs: dict = None

    def __post_init__(self):
        for k, v in self.kwargs.items():
            setattr(self, k, v)

        # mapping 
        example2idx = dict()
        example_ids = self.train_ds['id']
        for idx in range(len(example_ids)):
            example2idx[example_ids[idx]] = idx

        # per-process RNG seed (time + pid): each rank samples its own batch composition
        seed = int(time.time() * 1000) + os.getpid()
        self.rng = random.Random(seed)
        self.example2idx = example2idx
    
    def process_features(self,example_ids):
        examples = []
        for eid in example_ids:
            example = dict()

            record = self.train_ds[self.example2idx[eid]]
            example['id'] = eid
            example['input_ids'] = record['input_ids']
            example['generated'] = record['generated']
            example['attention_mask'] = record['attention_mask']
            examples.append(example)
        return examples

    def __call__(self, features):
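        # the features handed in by the DataLoader are only used for their count (the batch size);
        # the actual batch is re-sampled below: half positive / half negative examples drawn
        # from one randomly chosen prompt_id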
        bs = len(features)
        selected_prompt_id = self.rng.choice(self.prompt_ids)
        pos_samples = self.rng.sample(self.prompt2id_pos[selected_prompt_id], k=bs//2)
        neg_samples = self.rng.sample(self.prompt2id_neg[selected_prompt_id], k=bs//2)
        selected_examples = pos_samples + neg_samples
        features = self.process_features(example_ids=selected_examples)

        # extract labels 
        labels = None
        if 'generated' in features[0].keys():
            labels = torch.tensor([feature['generated'] for feature in features],dtype=torch.int64)
        
        features = [
            {
                'input_ids': feature['input_ids'],
                'attention_mask': feature['attention_mask']
            }
            for feature in features
        ]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of= self.pad_to_multiple_of,
            return_tensors=None
        )

        # convert the entries in the batch to a tensor of specified type 
        for key in ['input_ids','attention_mask']:
            batch[key] = torch.tensor(batch[key],dtype=torch.int64)

        batch['labels'] = labels

        return batch


@dataclass
class AiCollator(DataCollatorWithPadding):

    tokenizer = None 
    padding = None 
    max_length = None 
    pad_to_multiple_of = None 

    def __call__(self, features):

        # extract the labels 
        labels = None 
        if 'generated' in features[0].keys():
            labels = torch.tensor([feature['generated'] for feature in features], dtype=torch.int64)

        features = [
            {
                'input_ids':feature['input_ids'],
                'attention_mask': feature['attention_mask']
            }
            for feature in features
        ]

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length= self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=None
        )

        for key in ['input_ids','attention_mask']:
            batch[key] = torch.tensor(batch[key], dtype=torch.int64)
        
        batch['labels'] = labels

        return batch
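
To see where each process is actually stuck, one option (a sketch; faulthandler is in the standard library, and this is not part of the code above) is to dump every thread's Python stack periodically from inside main.py. Running py-spy dump --pid <pid> against the hung processes from another shell gives the same information without touching the code.

# hypothetical addition near the top of main.py: every 5 minutes, write the Python
# stack of every thread in every process to stderr, so a hang shows exactly which
# call (e.g. an NCCL collective or a DataLoader fetch) never returns
import faulthandler
faulthandler.dump_traceback_later(timeout=300, repeat=True)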

Could you try posting this as an issue on the Accelerate GitHub repo?

Sure, I will post it on the Accelerate GitHub.
