Different memory usage between DeepSpeed and FSDP

Recently I have been working on training a large model with both FSDP and DeepSpeed. However, with a similar configuration the GPU memory usage of the two is very different.

DeepSpeed:

import time
import torch
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

import deepspeed

DEEPSPEED_CONFIG = {
    'fp16': {
        'enabled': True,
    },
    'optimizer': {
        'type': 'AdamW',
        'params': {
            'lr': 1e-05, 'betas': [0.9, 0.999],
            'eps': 1e-08, 'weight_decay': 0.0
        }
    },
    'scheduler': {
        'type': 'WarmupLR',
        'params': {'warmup_min_lr': 0, 'warmup_max_lr': 1e-05, 'warmup_num_steps': 100}
    },
    'zero_optimization': {
        'stage': 3,
        'offload_optimizer': {'device': 'cpu', 'pin_memory': False},
        'offload_param': {'device': 'cpu', 'pin_memory': False},
    },
    'train_batch_size': 12,
    'train_micro_batch_size_per_gpu': 2,
    'gradient_accumulation_steps': 2,
}


MODEL_NAME = 'Salesforce/codegen-2B-mono'
epochs = 10


print(f'initializing model: {MODEL_NAME}')

config = AutoConfig.from_pretrained(MODEL_NAME)
config.gradient_checkpointing = True
config.use_cache = False

start = time.time()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config)
print(f"model loaded: {time.time() - start}s")

model.train()

print('initializing deepspeed')

model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = AdamW(model_parameters, lr=1e-5)
model_engine, optimizer, _, _ = deepspeed.initialize(
    config=DEEPSPEED_CONFIG, model=model, model_parameters=model_parameters,
    optimizer=optimizer)

torch.cuda.empty_cache()

print('starting training')

sentences = ["# this function prints hello world" for _ in range(DEEPSPEED_CONFIG['train_micro_batch_size_per_gpu'])]
inputs = tokenizer(sentences, return_tensors="pt")

for step in range(epochs):

    loss = model_engine(
        input_ids=inputs['input_ids'].cuda(),
        labels=inputs['input_ids'].cuda(),
    ).loss

    model_engine.backward(loss)
    model_engine.step()

    print(f'{step} {loss:8.3f}')

In this case, after model initialization, training uses less than 10k MB per GPU when running on 3 GPUs.

FairScale:

import os
import sys
import time

import torch
from torch.optim import AdamW
import torch.distributed as dist

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from fairscale.nn.data_parallel import (
    FullyShardedDataParallel as FSDP,
    ShardedDataParallel as SDP,
)
from fairscale.optim.oss import OSS
    
MODEL_NAME = "Salesforce/codegen-2B-mono"
epochs = 5

dist.init_process_group('nccl')
local_rank = int(os.getenv("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

config = AutoConfig.from_pretrained(MODEL_NAME)
config.gradient_checkpointing = True
config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Rank {local_rank}: Tokenizer Loaded.")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, config=config)
print(f"Rank {local_rank}: Model Loaded.")
dist.barrier()

model = FSDP(
    model,
    mixed_precision=True,
    move_params_to_cpu=True)
print(f"Rank {local_rank}: Model Wrapped.")

model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = OSS(model_parameters, AdamW, broadcast_fp16=True)

torch.cuda.empty_cache()
sentences = ["# this function prints hello world" for _ in range(2)]
inputs = tokenizer(sentences, return_tensors="pt")

for step in range(epochs):

    loss = model(
        input_ids=inputs['input_ids'].cuda(),
        labels=inputs['input_ids'].cuda(),
    ).loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f'{step} {loss:8.3f}')

However, in this case it uses up to 30k MB per GPU.

So is there anything wrong with my FairScale config? I would really appreciate it if someone could tell me!

By the way, I got this error when using OSS:

Traceback (most recent call last):
  File "codegen_fairscale.py", line 61, in <module>
    optimizer.step()
  File "/nvme/xingshuhao.dispatch/anaconda3/envs/fairscale/lib/python3.8/site-packages/torch/optim/optimizer.py", line 140, in wrapper
    out = func(*args, **kwargs)
  File "/nvme/xingshuhao.dispatch/anaconda3/envs/fairscale/lib/python3.8/site-packages/fairscale/optim/oss.py", line 237, in step
    self._broadcast_params()
  File "/nvme/xingshuhao.dispatch/anaconda3/envs/fairscale/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/nvme/xingshuhao.dispatch/anaconda3/envs/fairscale/lib/python3.8/site-packages/fairscale/optim/oss.py", line 614, in _broadcast_params
    dist.broadcast(
  File "/nvme/xingshuhao.dispatch/anaconda3/envs/fairscale/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1400, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
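
My guess from the traceback is that OSS broadcasts the parameter shards through torch.distributed, which fails here because move_params_to_cpu keeps the parameters on the CPU. As far as I understand, OSS is normally paired with ShardedDataParallel with the parameters left on the GPU, roughly like the sketch below (the lr value is my own choice for parity with the DeepSpeed config, and this setup loses the CPU offload I want):

# Sketch only: OSS paired with ShardedDataParallel, parameters kept on the GPU
# so the broadcasts inside OSS see CUDA tensors. Reuses model, tokenizer,
# inputs and epochs from the script above; no CPU offload here.
from torch.optim import AdamW
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as SDP

model = model.cuda()                          # keep parameters on the GPU
optimizer = OSS(model.parameters(), AdamW, lr=1e-5, broadcast_fp16=True)
model = SDP(model, optimizer)                 # ShardedDDP drives the OSS optimizer

for step in range(epochs):
    optimizer.zero_grad()
    loss = model(
        input_ids=inputs['input_ids'].cuda(),
        labels=inputs['input_ids'].cuda(),
    ).loss
    loss.backward()
    optimizer.step()
    print(f'{step} {loss:8.3f}')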

What’s the right way to use OSS?

Would it be possible for you to try out PyTorch’s native FullyShardedDataParallel?
https://pytorch.org/docs/stable/fsdp.html

Fairscale FSDP is not being actively developed anymore.
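
A rough sketch of what the native equivalent of your FairScale wrap could look like, with fp16 compute and CPU offload to mirror the DeepSpeed ZeRO-3 config (the CPUOffload / MixedPrecision argument names assume a recent PyTorch release, so please check them against your version):

# Sketch: native PyTorch FSDP with parameter CPU offload and fp16 mixed
# precision; reuses the model and AdamW setup from the scripts above.
import torch
from torch.optim import AdamW
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    MixedPrecision,
)

model = FSDP(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
    device_id=torch.cuda.current_device(),
)
optimizer = AdamW(model.parameters(), lr=1e-5)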

Thanks, I have already tried PyTorch's FSDP, but it doesn't solve my problem…

I submitted an issue: Memory usage different from deepspeed · Issue #1109 · facebookresearch/fairscale · GitHub. It seems that I need to manually wrap my modules, since a ModuleList cannot be wrapped automatically.
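
For anyone who runs into the same thing, this is roughly what I understand "manually wrapping" to mean, sketched with FairScale's wrap helpers (model.transformer.h is my assumption about where the ModuleList of CodeGen blocks lives, so adjust it for your model):

# Sketch: wrap each transformer block as its own FSDP unit so that only one
# block's full parameters have to be materialized on the GPU at a time.
# model.transformer.h is an assumption about the CodeGen module layout.
from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

fsdp_params = dict(mixed_precision=True, move_params_to_cpu=True)

with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
    for i, block in enumerate(model.transformer.h):
        model.transformer.h[i] = wrap(block)   # wrap each block individually
    model = wrap(model)                        # then wrap the remaining top level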