I'm running into an issue with distributed training using Hugging Face's Accelerate library. The training process freezes right after the dataloaders are set up when using multiple GPUs, but runs fine on a single GPU.
Environment:
- Python 3.10
- PyTorch 2.1
- Hugging Face Accelerate
- Model: DeBERTa-v3-base
- Kernel version: 5.4.0
- Using 2 GPUs
Command Used:
accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 main.py
Current Behavior:
- Training process initializes successfully (model loading, tokenization, and data mapping complete)
- Process freezes after the dataloader stage
- No error message is displayed - it simply stops proceeding
Additional Information:
- The code successfully runs on a single GPU
- Mixed precision (fp16) is enabled
- Data preprocessing appears successful (mapping shows 100% completion)
- Using NCCL backend for distributed training
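In case it helps with diagnosis, the run can be repeated with NCCL / torch.distributed debug output enabled (standard environment variables; I have not captured this output yet):
NCCL_DEBUG=INFO TORCH_DISTRIBUTED_DEBUG=DETAIL accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 main.py
The full console output from the failing run is below: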
root@af4cc4b13b7c:/workspace/embedding_layer# accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 main.py
The following values were not passed to `accelerate launch` and had defaults used instead:
`--num_machines` was set to a value of `1`
`--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
01/06/2025 11:56:05 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 2
Process index: 1
Local process index: 1
Device: cuda:1
Mixed precision type: fp16
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: jamesjohnson1097 (threado_ml). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.19.1
wandb: Run data is saved locally in /workspace/embedding_layer/wandb/run-20250106_115605-l9j5glhs
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fresh-sound-5
wandb: View project at https://wandb.ai/threado_ml/embed-train
wandb: View run at https://wandb.ai/threado_ml/embed-train/runs/l9j5glhs
01/06/2025 11:56:06 - INFO - __main__ - Distributed environment: MULTI_GPU Backend: nccl
Num processes: 2
Process index: 0
Local process index: 0
Device: cuda:0
Mixed precision type: fp16
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
setting the seed 42
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/config.json
Model config DebertaV2Config {
"_name_or_path": "microsoft/deberta-v3-base",
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-07,
"legacy": true,
"max_position_embeddings": 512,
"max_relative_positions": -1,
"model_type": "deberta-v2",
"norm_rel_ebd": "layer_norm",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"pooler_dropout": 0,
"pooler_hidden_act": "gelu",
"pooler_hidden_size": 768,
"pos_att_type": [
"p2c",
"c2p"
],
"position_biased_input": false,
"position_buckets": 256,
"relative_attention": true,
"share_att_key": true,
"transformers_version": "4.47.1",
"type_vocab_size": 0,
"vocab_size": 128100
}
loading file spm.model from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/spm.model
loading file tokenizer.json from cache at None
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--microsoft--deberta-v3-base/snapshots/8ccc9b6f36199bec6961081d44eb72fb3f7353f3/tokenizer_config.json
loading file chat_template.jinja from cache at None
/usr/local/lib/python3.10/dist-packages/transformers/convert_slow_tokenizer.py:561: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.
warnings.warn(
Map: 100%|██████████| 165767/165767 [01:24<00:00, 1969.81 examples/s]
Map: 100%|██████████| 165767/165767 [00:38<00:00, 4269.71 examples/s]
Map: 100%|██████████| 1679/1679 [00:00<00:00, 2088.85 examples/s]
Map: 100%|██████████| 1679/1679 [00:00<00:00, 3674.69 examples/s]
Below is the code:
#main.py
# A script to train the model to identify similar sentences
from ai_dataset import AiDataset
from ai_collator import AiCollatorTrain, AiCollator
from transformers import get_cosine_schedule_with_warmup
from ai_optimizer import get_optimizer
from ai_model import AiModel
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import pandas as pd
import os
import torch
from tqdm import tqdm
from torchmetrics import MeanMetric
import numpy as np
import shutil
from accelerate import Accelerator
from accelerate.utils import set_seed
import datasets
import transformers
from accelerate.logging import get_logger
import logging
import time
import wandb
class AverageMeter(object):
"""Computes and stores the average and current value
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
logger = get_logger(__name__)
def get_lr(optimizer):
return optimizer.param_groups[0]['lr']*1e6
def run_evaluation(model, valid_dl):
model.eval()
all_losses = []
progress_bar = tqdm(total=len(valid_dl), disable=not accelerator.is_local_main_process)
for batch in valid_dl:
with torch.no_grad():
loss = model(**batch)
batch_losses = accelerator.gather_for_metrics(loss) # losses gathered from all GPUs
batch_losses = batch_losses.cpu().numpy().tolist()
all_losses.extend(batch_losses)
progress_bar.update(1)
progress_bar.close()
eval_dict = dict()
eval_dict['valid_loss'] = np.mean(all_losses)
return eval_dict
def save_checkpoint(cfg, state, is_best = False):
os.makedirs(cfg.train_params.output_model_dir, exist_ok=True)
project_name = 'detect_ai'
file_name = f"{cfg.train_params.output_model_dir}/{project_name}_{state['epoch']}_{state['step']}.tar"
torch.save(state, file_name, _use_new_zipfile_serialization = False)
if is_best:
shutil.copyfile(file_name,f"{cfg.train_params.output_model_dir}/{project_name}_{state['epoch']}_{state['step']}_best.tar")
def get_latest_checkpoint(output_dir, best_only=True):
"""
Find the latest checkpoint in the output directory
Args:
output_dir: directory containing checkpoints
best_only: if True, only consider *_best.tar files
Returns:
path to latest checkpoint
"""
files = os.listdir(output_dir)
if best_only:
checkpoints = [f for f in files if f.endswith('_best.tar')]
else:
checkpoints = [f for f in files if f.endswith('.tar')]
if not checkpoints:
return None
# Extract epoch and step numbers
checkpoint_info = []
for ckpt in checkpoints:
parts = ckpt.replace('_best.tar', '').split('_') if 'best' in ckpt else ckpt.replace('.tar', '').split('_')
epoch, step = int(parts[-2]), int(parts[-1])
checkpoint_info.append((epoch, step, ckpt))
# Sort by epoch and step
checkpoint_info.sort(key=lambda x: (x[0], x[1]))
return os.path.join(output_dir, checkpoint_info[-1][2])
def print_line():
prefix, unit, suffix = "#", "~~", "#"
accelerator.print(prefix + unit*50 + suffix)
if __name__ == '__main__':
# load the config
cfg = OmegaConf.load('./conf/cfg.yaml')
# define the accelerate
accelerator = Accelerator(
log_with='wandb',
gradient_accumulation_steps=cfg.train_params.gradient_accumulation_steps
)
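# note: mixed precision (fp16) and the number of processes come from the `accelerate launch` flags, not from this constructor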
accelerator.init_trackers(
cfg.wandb.project,
config= OmegaConf.to_container(cfg=cfg,resolve=True)
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# In the main process, keep more verbose logging (datasets: warning, transformers: info)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else: # in the other processes, only log errors
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
print_line()
# set the seed
accelerator.print(f'setting the seed {cfg.seed}')
set_seed(cfg.seed)
# if it is the main process, then create a directory
if accelerator.is_local_main_process:
os.makedirs(cfg.train_params.output_model_dir,exist_ok=True)
print_line()
# read the parquet file
train_df = pd.read_parquet(os.path.join(cfg.input_data_dir,'train_essays.parquet'))
valid_df = pd.read_parquet(os.path.join(cfg.input_data_dir,'valid_essays.parquet'))
# reset the index
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
prompt_ids = train_df["prompt_id"].unique().tolist()
prompt_ids = [p for p in prompt_ids if p <= 8] # keep only prompt ids <= 8
pos_df = train_df[train_df['generated'] == 1].copy()
neg_df = train_df[train_df['generated'] == 0].copy()
# map each prompt_id to the list of its positive / negative example ids
pos_gdf = pos_df.groupby('prompt_id')['id'].apply(list).reset_index()
prompt2id_pos = dict(zip(pos_gdf['prompt_id'],pos_gdf['id']))
neg_gdf = neg_df.groupby('prompt_id')['id'].apply(list).reset_index()
prompt2id_neg = dict(zip(neg_gdf['prompt_id'], neg_gdf['id']))
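# run the tokenization/mapping on the main process first; the other process then reuses the cached dataset files instead of redoing the map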
with accelerator.main_process_first():
dataset = AiDataset(cfg=cfg)
train_ds = dataset.get_dataset(df=train_df)
valid_ds = dataset.get_dataset(df=valid_df)
tokenizer = dataset.tokenizer
# set the data loaders
train_ds.set_format(
type=None,
columns=[
'id',
'input_ids',
'attention_mask',
'generated'
]
)
valid_ds.set_format(
type=None,
columns=[
'id',
'input_ids',
'attention_mask',
'generated'
]
)
# extract the ids, useful for submission
valid_ids = valid_ds['id']
kwargs = dict(
train_ds=train_ds,
prompt_ids=prompt_ids,
prompt2id_pos=prompt2id_pos,
prompt2id_neg=prompt2id_neg
)
# Build the collator for train and valid
ai_collator_train = AiCollatorTrain(
tokenizer=tokenizer,
pad_to_multiple_of=64,
kwargs=kwargs
)
ai_collator = AiCollator(
tokenizer=tokenizer,
pad_to_multiple_of=64
)
train_dl = DataLoader(
train_ds,
shuffle=True,
batch_size=cfg.train_params.per_device_train_batch_size,
pin_memory=True,
collate_fn=ai_collator_train
)
valid_dl = DataLoader(
valid_ds,
shuffle=False,
batch_size=cfg.train_params.per_device_eval_batch_size,
pin_memory=True,
collate_fn=ai_collator
)
accelerator.print(f'Data preparation is done...')
print_line()
# ------ Model -------
model = AiModel(cfg=cfg, device=accelerator.device)
print_line()
# --- optimizer ------
optimizer = get_optimizer(cfg=cfg,model=model)
print_line()
# --- prepare the model ------
model, optimizer, train_dl, valid_dl = accelerator.prepare(model, optimizer, train_dl, valid_dl)
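# prepare() moves the model to the right device, wraps it in DistributedDataParallel, and shards the dataloaders so each process sees a different subset of the batches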
# ----- scheduler
num_train_epochs = cfg.train_params.num_train_epochs
gradient_accumulation_steps = cfg.train_params.gradient_accumulation_steps
warmup_pct = cfg.train_params.warmup_pct
num_update_steps_per_epoch = len(train_dl) // gradient_accumulation_steps
num_training_steps = num_train_epochs * num_update_steps_per_epoch
num_warmup_steps = int(warmup_pct * num_training_steps)
scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
# -- load the previous checkpoint if any --
checkpoint_path = get_latest_checkpoint(cfg.train_params.output_model_dir)
if checkpoint_path:
accelerator.print(f"Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
accelerator.wait_for_everyone()
# Unwrap model before loading state dict
model = accelerator.unwrap_model(model)
model.load_state_dict(checkpoint['state_dict'])
accelerator.print(f"Loaded checkpoint: {checkpoint_path}")
start_epoch = checkpoint['epoch']-1
current_iteration = checkpoint['step']
best_lb = checkpoint['lb']
# Adjust scheduler
for _ in range(current_iteration):
scheduler.step()
accelerator.print(f"Skipped Scheduler to current_iteration {current_iteration}")
else:
start_epoch = 0
current_iteration = 0
best_lb = 1e6
# --- training setup ----
patience_tracker = 0
# ----- training -----
start_time = time.time()
accelerator.wait_for_everyone()
progress_bar = None
for epoch in range(start_epoch, num_train_epochs):
# reset the progress for every epoch
if epoch != 0 and progress_bar is not None:
progress_bar.close()
loss_meter = AverageMeter()
progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
model.train()
for step, batch in enumerate(train_dl):
with accelerator.accumulate(model): # performs gradient accumulation
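# under accumulate(), the gradient all-reduce is skipped for intermediate micro-batches; accelerator.sync_gradients is True only on steps where gradients are synced and the optimizer should step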
loss = model(**batch)
accelerator.backward(loss)
# look for gradient accumulation trigger
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(),max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
loss_meter.update(loss.item())
if accelerator.sync_gradients:
progress_bar.set_description(
f'STEP: {current_iteration+1:5}/{num_update_steps_per_epoch:5}.'
f'LR: {get_lr(optimizer):.4f}.'
f'LOSS: {loss_meter.avg:.4f}.'
)
progress_bar.update(1)
current_iteration += 1
if cfg.use_wandb:
accelerator.log({'train_loss': loss_meter.avg},step=current_iteration)
accelerator.log({'lr':get_lr(optimizer=optimizer)}, step=current_iteration)
if accelerator.sync_gradients and (current_iteration % cfg.train_params.eval_frequency == 0):
model.eval()
scores_dict = run_evaluation(model,valid_dl)
lb = scores_dict['valid_loss']
print_line()
et = time.time() - start_time
accelerator.print(
f">>> Epoch {epoch+1} | Step {step} | Total Step {current_iteration} | Time: {et}"
)
print_line()
accelerator.print(f">>> Current LB (valid_loss) = {round(lb, 4)}")
is_best = False
if lb <= best_lb:
best_lb = lb
is_best = True
patience_tracker = 0
else:
patience_tracker += 1
accelerator.wait_for_everyone()
unwraped_model = accelerator.unwrap_model(model)
model_state = {
'step': current_iteration,
'epoch': epoch + 1,
'state_dict':unwraped_model.state_dict(),
'lb':lb
}
if accelerator.is_main_process: # save the checkpoint only in main process
save_checkpoint(cfg,state=model_state, is_best=is_best)
model.train()
torch.cuda.empty_cache()
print_line()
if patience_tracker >= cfg.train_params.patience:
accelerator.print('No improvement in validation loss.')
model.eval()
accelerator.end_training()
break
#optimizer.py
import torch.optim as optim
'''
Group the model parameters into decay / no_decay sets (optionally with layer-wise learning-rate decay) for the optimizer.
'''
def get_optimizer_grouped_parameters_with_llrd(cfg, model):
no_decay = ['bias','LayerNorm.bias', 'LayerNorm.weight']
# set the hyperparameters for the head (non-backbone) parameters
optimizer_grouped_parameters = [{
'params':[p for n,p in model.named_parameters() if 'backbone' not in n],
'lr': cfg.optimizer.head_lr,
'weight_decay': cfg.optimizer.weight_decay
}]
layers = [model.backbone.embeddings] + list(model.backbone.encoder.layer)
layers.reverse() # top encoder layer first so it gets the largest LR; embeddings get the smallest
lr = cfg.optimizer.lr
for layer in layers:
lr *= cfg.optimizer.llrd
optimizer_grouped_parameters += [
{ # decayable parameters()
'params': [p for n,p in layer.named_parameters() if not any(nd in n for nd in no_decay)],
'lr':lr,
'weight_decay': cfg.optimizer.weight_decay
},
{ # non decayable parameters()
'params': [p for n,p in layer.named_parameters() if any(nd in n for nd in no_decay)],
'lr':lr,
'weight_decay':0.0
}
]
return optimizer_grouped_parameters
def get_optimizer_grouped_parameters_with_no_llrd(cfg, model):
no_decay = ['bias','LayerNorm.weight','LayerNorm.bias']
backbone_params = list(model.backbone.named_parameters()) # materialize so the parameters can be iterated more than once below
optimizer_grouped_parameters = [
{
'params': [p for n,p in model.named_parameters() if 'backbone' not in n],
'lr':cfg.optimizer.lr,
'weight_decay': cfg.optimizer.weight_decay
},
{
'params': [p for n,p in backbone_params if not any(nd in n for nd in no_decay)],
'lr':cfg.optimizer.lr,
'weight_decay': cfg.optimizer.weight_decay
},
{
'params':[p for n, p in backbone_params if any(nd in n for nd in no_decay)],
'lr':cfg.optimizer.lr,
'weight_decay':0.0
}
]
return optimizer_grouped_parameters
def get_optimizer(cfg, model):
# configure the optimizer based on learning rate decay
optimizer_grouped_parameters = None
if cfg.optimizer.use_llrd:
optimizer_grouped_parameters = get_optimizer_grouped_parameters_with_llrd(cfg, model)
else:
optimizer_grouped_parameters = get_optimizer_grouped_parameters_with_no_llrd(cfg, model)
# define the optimizer
optimizer_kwargs = {
'betas': (cfg.optimizer.beta1, cfg.optimizer.beta2),
'eps': cfg.optimizer.eps,
'lr': cfg.optimizer.lr
}
if cfg.optimizer.use_bnb:
import bitsandbytes as bnb
return bnb.optim.Adam8bit(
optimizer_grouped_parameters,
**optimizer_kwargs
)
else:
return optim.AdamW(
optimizer_grouped_parameters,
**optimizer_kwargs
)
#model.py
from transformers import AutoConfig, AutoModel
import torch.nn as nn
import torch.nn.functional as F
import torch
class MeanPooling(nn.Module):
def __init__(self):
super(MeanPooling, self).__init__()
def forward(self, last_hidden_state, attention_mask):
'''
last_hidden_state: [batch_size, seq_len, hidden_dim]
attention_mask: [batch_size, seq_len]
'''
input_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
sum_embeddings = torch.sum(last_hidden_state * input_mask, dim=1)
sum_attention = input_mask.sum(dim = 1)
# clamp to avoid division by zero when an attention mask row sums to 0
sum_attention = torch.clamp(sum_attention, min=1e-9)
mean_embeddings = sum_embeddings / sum_attention
return mean_embeddings
# define the loss function
class SupContrastiveLoss(nn.Module):
def __init__(self, temperature, device):
super(SupContrastiveLoss, self).__init__()
self.device = device
self.temperature = temperature
def forward(self,normalized_embeddings, labels):
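# normalized_embeddings: [N, d] L2-normalized projections; labels: [N] 0/1 'generated' flags
# same-label pairs are treated as positives, different-label pairs as negatives; the diagonal (self-similarity) is masked out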
N = normalized_embeddings.size()[0]
labels = labels.reshape((N,1))
similarity_mask = torch.ones((N,N)).fill_diagonal_(0).to(device=self.device)
pos_mask = torch.eq(labels, labels.T).float()
neg_mask = torch.abs(pos_mask - 1)
H = torch.matmul(normalized_embeddings,normalized_embeddings.T) * similarity_mask
H_pos = H * pos_mask
H_neg = H * neg_mask
v_pos = torch.mean(torch.exp(torch.div(H_pos,self.temperature)),dim=1)
v_neg = torch.mean(torch.exp(torch.div(H_neg, self.temperature)),dim=1)
loss = (-1/N) * torch.sum(torch.log(v_pos/(v_pos + v_neg)))
return loss
class AiModel(nn.Module):
def __init__(self, cfg, device=None):
super(AiModel,self).__init__()
backbone_config = AutoConfig.from_pretrained(cfg.model.backbone_path)
backbone_config.update(
{
'use_cache':False
}
)
self.backbone = AutoModel.from_pretrained(
cfg.model.backbone_path,
config = backbone_config
)
if cfg.model.gradient_checkpoint:
self.backbone.gradient_checkpointing_enable()
# define dropout, pool, projection_head
self.dropout = nn.Dropout(cfg.model.dropout_rate)
self.pool = MeanPooling()
self.loss_fn = SupContrastiveLoss(temperature=cfg.model.temperature, device=device)
hidden_size, projection_dim = self.backbone.config.hidden_size, cfg.model.projection_dim
self.projection_head = nn.Sequential(
nn.Dropout(cfg.model.dropout_rate),
nn.Linear(hidden_size, projection_dim),
nn.ReLU(),
nn.Linear(projection_dim, projection_dim)
)
def forward(self, input_ids, attention_mask, labels=None):
outputs = self.backbone(input_ids, attention_mask=attention_mask, output_hidden_states=False)
last_hidden_state = outputs.last_hidden_state
embeddings = self.pool(last_hidden_state, attention_mask)
projection_space = self.projection_head(embeddings)
normalized_proj = F.normalize(projection_space, dim=-1)
loss = None
if labels is not None:
loss = self.loss_fn(normalized_proj, labels)
return loss
#dataset.py
from transformers import AutoTokenizer
from copy import deepcopy
from datasets import Dataset
class AiDataset:
def __init__(self, cfg):
self.cfg = cfg
self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.backbone_path)
def tokenize(self, examples):
tz = self.tokenizer(
examples['text'],
padding=False,
truncation=True,
max_length=self.cfg.model.max_length,
add_special_tokens=True,
return_token_type_ids=False
)
return tz
def compute_length(self,examples):
return {'input_length': [len(x) for x in examples['input_ids']]}
def get_dataset(self, df):
dataset = Dataset.from_pandas(df)
dataset = dataset.map(self.tokenize,batch_size=32, batched=True)
dataset = dataset.map(self.compute_length,batch_size=32, batched=True)
return dataset
#collator.py
from transformers import DataCollatorWithPadding
from dataclasses import dataclass, field
import time
import os
import random
import torch
@dataclass
class AiCollatorTrain(DataCollatorWithPadding):
tokenizer = None
padding = None
max_length = None
pad_to_multiple_of = None
kwargs: dict = None
def __post_init__(self):
# copy the sampling helpers passed in via kwargs (train_ds, prompt_ids, prompt2id_pos, prompt2id_neg) onto the collator
for k, v in self.kwargs.items():
setattr(self, k, v)
# mapping
example2idx = dict()
example_ids = self.train_ds['id']
for idx in range(len(example_ids)):
example2idx[example_ids[idx]] = idx
seed = int(time.time() * 1000) + os.getpid()
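# seed derived from wall-clock time and PID so every process builds its own RNG and samples different prompts/examples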
self.rng = random.Random(seed)
self.example2idx = example2idx
def process_features(self,example_ids):
examples = []
for eid in example_ids:
example = dict()
record = self.train_ds[self.example2idx[eid]]
example['id'] = eid
example['input_ids'] = record['input_ids']
example['generated'] = record['generated']
example['attention_mask'] = record['attention_mask']
examples.append(example)
return examples
def __call__(self, features):
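# the features coming from the DataLoader are only used for the batch size; the batch is re-sampled here: pick one prompt_id, then bs//2 AI-generated and bs//2 human examples from that prompt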
bs = len(features)
selected_prompt_id = self.rng.choice(self.prompt_ids)
pos_samples = self.rng.sample(self.prompt2id_pos[selected_prompt_id], k=bs//2)
neg_samples = self.rng.sample(self.prompt2id_neg[selected_prompt_id], k=bs//2)
selected_examples = pos_samples + neg_samples
features = self.process_features(example_ids=selected_examples)
# extract labels
labels = None
if 'generated' in features[0].keys():
labels = torch.tensor([feature['generated'] for feature in features],dtype=torch.int64)
features = [
{
'input_ids': feature['input_ids'],
'attention_mask': feature['attention_mask']
}
for feature in features
]
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of= self.pad_to_multiple_of,
return_tensors=None
)
# convert the entries in the batch to a tensor of specified type
for key in ['input_ids','attention_mask']:
batch[key] = torch.tensor(batch[key],dtype=torch.int64)
batch['labels'] = labels
return batch
@dataclass
class AiCollator(DataCollatorWithPadding):
tokenizer = None
padding = None
max_length = None
pad_to_multiple_of = None
def __call__(self, features):
# extract the labels
labels = None
if 'generated' in features[0].keys():
labels = torch.tensor([feature['generated'] for feature in features], dtype=torch.int64)
features = [
{
'input_ids':feature['input_ids'],
'attention_mask': feature['attention_mask']
}
for feature in features
]
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length= self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=None
)
for key in ['input_ids','attention_mask']:
batch[key] = torch.tensor(batch[key], dtype=torch.int64)
batch['labels'] = labels
return batch