Unexpected CPU memory usage when merging LoRA in an FSDP model

I want to obtain the state dict of an FSDP model with the LoRA weights merged in. However, I don't have enough GPU memory to summon the full parameters of the FSDP model at once, and merging the model entirely on CPU is too slow. So I decided to summon the parameters of the FSDP model part by part on GPU, merge them there, and offload the merged state dict to CPU. However, this process consumes more CPU memory than expected. Why?

The following is my code:
main.py

import os

os.environ['NCCL_DEBUG'] = 'WARN'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import torch
import torch.distributed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig

from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager
from get_merged_state_dict import get_merged_state_dict

from peft import LoraConfig, TaskType, get_peft_model
from torch.distributed import init_device_mesh

local_model_path = 'qwen2.5-32b'

def main():
    torch.distributed.init_process_group('nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
        
    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))

    config = AutoConfig.from_pretrained(local_model_path)
    init_context = get_init_weight_context_manager(use_meta_tensor=True)

    with init_context():
        model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            local_model_path,
            config=config,
            torch_dtype=torch.bfloat16,
            attn_implementation='flash_attention_2'
        )
        
        model.enable_input_require_grads()
        lora_config = {
            'task_type': TaskType.CAUSAL_LM,
            'r': 32,
            'lora_alpha': 32,
            'target_modules': 'all-linear',
            'bias': "none"
        }
        model = get_peft_model(model, LoraConfig(**lora_config))
    
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                     reduce_dtype=torch.float32,
                                     buffer_dtype=torch.float32)
    
    auto_wrap_policy = get_fsdp_wrap_policy(
        model,
        config={'min_num_params': 1e7},
        is_lora=True
    )
    
    fsdp_model = FSDP(module=model,
                      auto_wrap_policy=auto_wrap_policy,
                      param_init_fn=init_fn,
                      sharding_strategy=ShardingStrategy.FULL_SHARD,
                      mixed_precision=mixed_precision,
                      device_mesh=device_mesh,
                      sync_module_states=True,
                      device_id=torch.cuda.current_device(),
                      cpu_offload=None,
                      use_orig_params=False)
    torch.manual_seed(0)

    res = get_merged_state_dict(fsdp_model)
    print(res.keys())
    
    print('success')
    
if __name__ == '__main__':
    main()

get_merged_state_dict.py

from peft.tuners.lora.layer import LoraLayer
from peft.tuners.lora import LoraModel
from torch.distributed import breakpoint as dbp
from torch.distributed import get_rank
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch
from collections import OrderedDict
import gc
import tracemalloc

def get_state_dict_size(state_dict, format='MB'):
    def _get_size_in_bytes(tensor):
        return tensor.nelement() * tensor.element_size()
    
    total_size = 0
    for key, param in state_dict.items():
        if isinstance(param, torch.Tensor):
            total_size += _get_size_in_bytes(param)
    
    if format == 'B':
        return total_size
    elif format == 'KB':
        return total_size / 1024
    elif format == 'MB':
        return total_size / (1024 ** 2)
    elif format == 'GB':
        return total_size / (1024 ** 3)

def get_merged_state_dict(self: LoraModel):
    state_dict = {}
    RANK = get_rank()
    
    def _traverse(model, prefix):
        name = None
        for name, module in model.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(module, LoraLayer):
                std = get_lora_merged_state_dict(module)
                for n in std:
                    if RANK==0:
                        print(f"{full_name}.{n}", std[n].dtype)
                    state_dict[f"{full_name}.{n}"] = std[n]
                continue
            else:
                if isinstance(module, FSDP):
                    # Un-shard only this FSDP unit (recurse=False) so that a single
                    # unit's full parameters live on GPU at a time.
                    with FSDP.summon_full_params(module, recurse=False, writeback=False):
                        _traverse(module, full_name)
                else:
                    _traverse(module, full_name)
                
        if name is None:
            # Leaf module with no children: copy its own parameters/buffers to CPU as-is.
            std = model.state_dict()
            for n in std:
                if RANK == 0:
                    print(f"{prefix}.{n}", std[n].dtype)
                    print(get_state_dict_size(state_dict, format='GB'))

                state_dict[f"{prefix}.{n}"] = std[n].cpu()
    
    
    with FSDP.summon_full_params(self, recurse=False, writeback=False):
        _traverse(self, '')
    
    new_state_dict = OrderedDict(
        (
            k.replace('.base_layer.', '.')\
             .replace("_fsdp_wrapped_module.", "")\
             .replace("base_model.", "")\
             .replace(".model.", '.'), 
            v
        ) 
        for k, v in state_dict.items()
    )

    return new_state_dict

@torch.no_grad()
def get_lora_merged_state_dict(model: LoraLayer):
    output_dict = {}
    # Issue the device-to-host copies on a side CUDA stream.
    stream = torch.cuda.Stream()

    with torch.cuda.stream(stream):
        with FSDP.summon_full_params(model, writeback=False):
            weight_shape = model.base_layer.weight.shape
            # Pinned (page-locked) CPU buffer so copy_(..., non_blocking=True) can be asynchronous.
            cpu_weight = torch.empty(weight_shape, device='cpu', pin_memory=True, dtype=torch.bfloat16)
            delta_weight = model.get_delta_weight('default')
            merged_weight = model.base_layer.weight.data + delta_weight

            cpu_weight.copy_(merged_weight, non_blocking=True)
            output_dict['weight'] = cpu_weight

            if model.base_layer.bias is not None:
                cpu_bias = torch.empty_like(model.base_layer.bias, device='cpu', pin_memory=True)
                bias = model.base_layer.bias.data
                if model.lora_bias['default']:
                    bias = bias + model.lora_B['default'].bias
                cpu_bias.copy_(bias, non_blocking=True)
                output_dict['bias'] = cpu_bias

    # Make sure the asynchronous copies have finished before the pinned CPU
    # buffers are handed back to the caller.
    stream.synchronize()

    return output_dict

When I run torchrun --nproc-per-node 8 main.py,
it gets OOM-killed on a machine with 800 GB of CPU memory.

The merged state dict is at most 64 GB, so 8 processes should consume roughly 512 GB in total. Where does the additional CPU memory come from?
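
For reference, this is the minimal diagnostic sketch I am using to track per-rank memory around the merge. It only uses the standard library; the helper name log_peak_rss is just something I picked for illustration, not part of the code above:

import resource
import torch.distributed as dist

def log_peak_rss(tag: str):
    # ru_maxrss is this process's peak resident set size,
    # reported in kilobytes on Linux.
    peak_gb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)
    print(f"[rank {dist.get_rank()}] {tag}: peak RSS ~ {peak_gb:.2f} GB")

# Usage in main():
#   log_peak_rss('before merge')
#   res = get_merged_state_dict(fsdp_model)
#   log_peak_rss('after merge')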