Hello, I wrote the following training script and am running it on a single 40GB A100 for the time being. I am sure the model fits on the A100 (model.to() works fine, and nvidia-smi shows 28GB in use), but when I call FSDP(model), it tries to allocate more than 40GB in total. When I pass auto_wrap_policy=fsdp_auto_wrap_policy as an argument, it allocates only an extra 2GB, so I have found a way around the limit, but my impression was that FSDP shouldn’t increase memory usage at all. The same thing happens with multiple GPUs when using the default wrap policy: each GPU first receives the full 28GB model and then tries to allocate more than 12GB on top of that, which crashes. Why is this?
import os
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
def ip():
    """Return the IPv4 address that this machine's own hostname resolves to."""
    import socket as _socket

    hostname = _socket.gethostname()
    return _socket.gethostbyname(hostname)
def _collate_batch(elements, pad_token_id=0, label_pad_id=-100):
    """Right-pad a list of tokenized examples into one training batch.

    Each element is a mapping with an ``input_ids`` list (as produced by the
    tokenizing ``dataset.map`` in ``_build_train_loader``). Returns a dict with
    ``input_ids``, ``attention_mask`` and ``labels`` tensors, each of shape
    ``(len(elements), longest_sequence)``.

    Args:
        elements: the examples the DataLoader grouped into this batch.
        pad_token_id: id written into padded ``input_ids`` positions; those
            positions are also zeroed in ``attention_mask``.
        label_pad_id: value written into padded ``labels`` positions. The
            default -100 is ``torch.nn.CrossEntropyLoss``'s default
            ``ignore_index``, so padding contributes nothing to the loss.
            Padding labels with a real token id would instead train the model
            to predict that token at every padded position.
    """
    import torch

    token_lists = [e["input_ids"] for e in elements]
    max_len = max(len(tl) for tl in token_lists)
    input_ids, attention_mask, labels = [], [], []
    for tl in token_lists:
        n_pad = max_len - len(tl)
        input_ids.append(tl + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(tl) + [0] * n_pad)
        labels.append(tl + [label_pad_id] * n_pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }


def _build_train_loader(tokenizer):
    """Tokenize ``dataset.txt`` and return ``(train_loader, sampler)``.

    The file is split into one example per paragraph. A
    ``DistributedSampler`` gives each rank its own subset, so the process
    group must already be initialized. The sampler is returned so the caller
    can call ``sampler.set_epoch`` once per epoch.
    """
    import functools

    import torch
    from datasets import load_dataset

    dataset = load_dataset("text", data_files={"train": "dataset.txt"}, sample_by="paragraph")

    def tokenize(examples):
        return tokenizer(examples["text"])

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    # The tokenizer may define no pad token (pad_token_id is None); fall back
    # to id 0. Padded positions are masked by attention_mask and ignored in
    # the loss, so the exact id does not affect training.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    sampler = torch.utils.data.distributed.DistributedSampler(tokenized["train"])
    train_loader = torch.utils.data.DataLoader(
        tokenized["train"],
        sampler=sampler,
        collate_fn=functools.partial(_collate_batch, pad_token_id=pad_id),
    )
    return train_loader, sampler


def _wrap_with_fsdp(model, rank):
    """Shard *model* with FSDP, using one FSDP unit per ``MistralDecoderLayer``.

    *model* is expected to still be on the CPU. ``device_id=rank`` makes FSDP
    move each unit to ``cuda:rank`` as that unit is initialized.

    Why the auto-wrap policy matters: without one, the whole model is a single
    FSDP unit. FSDP then moves the entire model to the GPU and concatenates
    all of its parameters into one flat parameter before sharding it. During
    initialization the GPU therefore briefly holds the original parameters
    *and* a flat copy of the same size, roughly twice the model. That
    overflows a 40GB card even though the model alone fits, and it happens on
    every rank. With one unit per decoder layer, each layer is moved,
    flattened and sharded in turn. The transient overhead is then about one
    layer, plus whatever parameters live outside the decoder layers (those
    stay in the root unit).
    """
    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={MistralDecoderLayer},
    )
    return FSDP(model, auto_wrap_policy=wrap_policy, device_id=rank)


def train(rank, world_size):
    """Per-process worker launched by ``mp.spawn``.

    Joins the NCCL process group, loads Mistral-7B on the CPU, builds the
    rank-sharded dataloader and wraps the model in FSDP.

    Args:
        rank: index of this process (supplied by ``mp.spawn``). It is also
            used as the CUDA device index, which assumes a single node with
            one process per GPU.
        world_size: total number of processes in the group.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    torch.cuda.set_device(rank)

    # torchrun exports MASTER_ADDR/MASTER_PORT, but mp.spawn does not. Fall
    # back to a single-node rendezvous instead of raising KeyError when the
    # script is started with plain `python`.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    print(
        f"{rank}: ADDR={ip()} "
        f"MASTER_ADDR={os.environ['MASTER_ADDR']} MASTER_PORT={os.environ['MASTER_PORT']}"
    )
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", cache_dir="/anvil/scratch/")
        # Loaded on the CPU; _wrap_with_fsdp moves it to the GPU one unit at a
        # time. NOTE(review): no torch_dtype is passed, so this presumably
        # loads fp32 weights (~4 bytes/param, consistent with the ~28GB seen
        # in nvidia-smi). torch_dtype=torch.bfloat16 or FSDP MixedPrecision
        # would roughly halve parameter memory.
        model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", cache_dir="/anvil/scratch/")
        train_loader, sampler = _build_train_loader(tokenizer)
        fsdp_model = _wrap_with_fsdp(model, rank)

        # Training loop, disabled while diagnosing FSDP initialization memory.
        # The optimizer must be built from fsdp_model.parameters() *after*
        # wrapping. NOTE(review): AdamW also keeps two state tensors per
        # parameter, plus gradients, so full fine-tuning needs several times
        # the parameter memory even when sharded.
        # NOTE(review): the original placement of scheduler.step() (per batch
        # vs per epoch) was lost in the paste; confirm which was intended.
        # optimizer = torch.optim.AdamW(fsdp_model.parameters())
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # for epoch in range(1, 9):
        #     fsdp_model.train()
        #     sampler.set_epoch(epoch)
        #     for inputs in train_loader:
        #         inputs = {k: v.to(rank) for k, v in inputs.items()}
        #         optimizer.zero_grad()
        #         out = fsdp_model(**inputs)
        #         out["loss"].backward()
        #         optimizer.step()
        #     scheduler.step()
    finally:
        dist.destroy_process_group()
import torch.multiprocessing as mp
if __name__ == '__main__':
    # Launch one worker per visible GPU. mp.spawn calls train(i, *args), so
    # each worker receives its index as `rank` and the GPU count as `world_size`.
    import torch

    n_gpus = torch.cuda.device_count()
    mp.spawn(train, args=(n_gpus,), nprocs=n_gpus, join=True)
PS: Before I wrote the custom wrap policy, the memory constraint meant I couldn’t move the model to the GPU with .to() or device_id=rank before calling FSDP. Instead, I kept the model on the CPU, which avoided the crash, and the shards were moved to the GPU afterwards.