I want to get the state dict of an FSDP model with the LoRA weights merged. However, I don't have enough GPU memory to summon the full parameters of the FSDP model at once, and merging the model on CPU is too slow. So I decided to summon parts of the FSDP model's parameters on GPU and offload the merged state dict to CPU. However, this process consumes more CPU memory than expected. Why?
The following is my code:
main.py
import os
os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from get_merged_state_dict import get_merged_state_dict
from peft import LoraConfig, TaskType, get_peft_model
from torch.distributed import init_device_mesh
local_model_path = 'qwen2.5-32b'

def main():
    torch.distributed.init_process_group('nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))
    config = AutoConfig.from_pretrained(local_model_path)
    init_context = get_init_weight_context_manager(use_meta_tensor=True)
    with init_context():
        model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            local_model_path,
            config=config,
            torch_dtype=torch.bfloat16,
            attn_implementation='flash_attention_2'
        )
        model.enable_input_require_grads()
        lora_config = {
            'task_type': TaskType.CAUSAL_LM,
            'r': 32,
            'lora_alpha': 32,
            'target_modules': 'all-linear',
            'bias': "none"
        }
        model = get_peft_model(model, LoraConfig(**lora_config))
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                     reduce_dtype=torch.float32,
                                     buffer_dtype=torch.float32)
    auto_wrap_policy = get_fsdp_wrap_policy(
        model,
        config={'min_num_params': 1e7},
        is_lora=True
    )
    fsdp_model = FSDP(module=model,
                      auto_wrap_policy=auto_wrap_policy,
                      param_init_fn=init_fn,
                      sharding_strategy=ShardingStrategy.FULL_SHARD,
                      mixed_precision=mixed_precision,
                      device_mesh=device_mesh,
                      sync_module_states=True,
                      device_id=torch.cuda.current_device(),
                      cpu_offload=None,
                      use_orig_params=False)
    torch.manual_seed(0)
    res = get_merged_state_dict(fsdp_model)
    print(res.keys())
    print('success')


if __name__ == '__main__':
    main()

get_merged_state_dict.py
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.lora import LoraModel
from torch.distributed import breakpoint as dbp
from torch.distributed import get_rank
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch
from collections import OrderedDict
import gc
import tracemalloc

def get_state_dict_size(state_dict, format='MB'):
    def _get_size_in_bytes(tensor):
        return tensor.nelement() * tensor.element_size()

    total_size = 0
    for key, param in state_dict.items():
        if isinstance(param, torch.Tensor):
            total_size += _get_size_in_bytes(param)
    if format == 'B':
        return total_size
    elif format == 'KB':
        return total_size / 1024
    elif format == 'MB':
        return total_size / (1024 ** 2)
    elif format == 'GB':
        return total_size / (1024 ** 3)

def get_merged_state_dict(self: LoraModel):
    state_dict = {}
    RANK = get_rank()

    def _traverse(model, prefix):
        name = None
        for name, module in model.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(module, LoraLayer):
                # Merge this LoRA layer and collect its CPU tensors
                std = get_lora_merged_state_dict(module)
                for n in std:
                    if RANK == 0:
                        print(f"{full_name}.{n}", std[n].dtype)
                    state_dict[f"{full_name}.{n}"] = std[n]
                continue
            else:
                if isinstance(module, FSDP):
                    # Summon only this submodule's full parameters, then recurse
                    with FSDP.summon_full_params(module, False, writeback=False):
                        _traverse(module, full_name)
                else:
                    _traverse(module, full_name)
        if name is None:
            # Leaf module with no children: copy its state dict to CPU directly
            std = model.state_dict()
            for n in std:
                if RANK == 0:
                    print(f"{prefix}.{n}", std[n].dtype)
                    print(get_state_dict_size(state_dict, format='GB'))
                state_dict[f"{prefix}.{n}"] = std[n].cpu()

    with FSDP.summon_full_params(self, False, writeback=False):
        _traverse(self, '')

    new_state_dict = OrderedDict(
        (
            k.replace('.base_layer.', '.')
             .replace("_fsdp_wrapped_module.", "")
             .replace("base_model.", "")
             .replace(".model.", '.'),
            v
        )
        for k, v in state_dict.items()
    )
    return new_state_dict

@torch.no_grad
def get_lora_merged_state_dict(model: LoraLayer):
    output_dict = {}
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        with FSDP.summon_full_params(model, writeback=False):
            # Merge base weight + LoRA delta on GPU, then copy the result into a
            # freshly allocated pinned CPU buffer with a non-blocking copy
            weight_shape = model.base_layer.weight.shape
            cpu_weight = torch.empty(weight_shape, device='cpu', pin_memory=True, dtype=torch.bfloat16)
            delta_weight = model.get_delta_weight('default')
            merged_weight = model.base_layer.weight.data + delta_weight
            cpu_weight.copy_(merged_weight, non_blocking=True)
            output_dict['weight'] = cpu_weight
            if model.base_layer.bias is not None:
                cpu_bias = torch.empty_like(model.base_layer.bias, device='cpu', pin_memory=True)
                bias = model.base_layer.bias.data
                if model.lora_bias['default']:
                    bias = bias + model.lora_B['default'].bias
                cpu_bias.copy_(bias, non_blocking=True)
                output_dict['bias'] = cpu_bias
    return output_dict
When I run torchrun --nproc-per-node 8 main.py, it OOMs on a machine with 800 GB of CPU memory.
The maximum size of the merged state_dict is 64 GB, so 8 processes should consume roughly 512 GB. What are the additional CPU memory costs?
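
For reference, a minimal sketch of how the per-rank CPU usage could be measured around the merge call (this assumes psutil is installed; print_cpu_rss is an illustrative helper, not part of the code above):

import psutil
import torch.distributed

def print_cpu_rss(tag):
    # resident set size (RSS) of this process, in GB
    rss_gb = psutil.Process().memory_info().rss / (1024 ** 3)
    print(f"[rank {torch.distributed.get_rank()}] {tag}: RSS = {rss_gb:.1f} GB")

print_cpu_rss('before merge')
res = get_merged_state_dict(fsdp_model)
print_cpu_rss('after merge')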