I cannot prevent the model from OOMing under FSDP. To my knowledge, a 7B model should not OOM at a 16k sequence length (and is expected to OOM at 32k), yet the repro below already OOMs at 16k. How can I prevent the OOM here?
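For context, here is my back-of-envelope accounting (a rough sketch only; the ~152k vocab size is my assumption from the Qwen2.5 config, and I am ignoring optimizer state and activations):

# Rough memory arithmetic for a 7B model (assumes ~152k vocab).
params = 7e9
GiB = 1024**3
print(f"fp32 weights: {params * 4 / GiB:.1f} GiB")  # ~26 GiB per full, unsharded copy
print(f"bf16 weights: {params * 2 / GiB:.1f} GiB")  # ~13 GiB
for seq in (16_384, 32_768):
    n_logits = seq * 152_000  # logits for a single sequence
    print(f"logits @ {seq}: fp32 {n_logits * 4 / GiB:.1f} GiB, bf16 {n_logits * 2 / GiB:.1f} GiB")

The repro script: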
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    CPUOffload,
    BackwardPrefetch,
)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from transformers import AutoModelForCausalLM
from functools import partial
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
def get_fsdp_wrap_policy(module):
    """Build an FSDP auto-wrap policy that wraps each transformer decoder layer."""
    if hasattr(module, "get_decoder"):
        transformer_layer_cls = module.get_decoder().layers[0].__class__
        return partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={transformer_layer_cls},
        )
    if hasattr(module, "layers"):
        transformer_layer_cls = module.layers[0].__class__
        return partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={transformer_layer_cls},
        )
    raise NotImplementedError(
        "This module is not a decoder-only transformer. Please implement the `fsdp_wrap_policy` method."
    )
if __name__ == "__main__":
    length = 16384
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    # Dummy batch: a single sequence of `length` random token ids.
    input_ids = torch.randint(0, 100, (1, length), device=torch.cuda.current_device())
    position_ids = torch.arange(length, device=torch.cuda.current_device()).unsqueeze(0)

    # Note: the full fp32 copy is materialized on every rank before FSDP shards it.
    module = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        attn_implementation="flash_attention_2",
        device_map="cuda",
        torch_dtype=torch.float32,
    )
    module.gradient_checkpointing_enable()
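    # Sanity check (illustrative): for Qwen2ForCausalLM the wrap policy should
    # end up wrapping each Qwen2DecoderLayer as its own FSDP unit.
    assert module.get_decoder().layers[0].__class__ is Qwen2DecoderLayer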
    model = FSDP(
        module,
        process_group=dist.group.WORLD,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=get_fsdp_wrap_policy(module),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        device_id=torch.cuda.current_device(),  # expects an int or torch.device, not the string "cuda"
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # the enum, not the string "BACKWARD_PRE"
        use_orig_params=False,
    )
    # The following forward pass causes the OOM.
    logits = model(input_ids, position_ids=position_ids).logits
    print(logits.shape, logits.dtype)
    dist.destroy_process_group()
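In case it matters, these are the two variants I am considering but have not verified (a sketch, not a confirmed fix): loading the checkpoint directly in bf16 so no rank ever materializes the full fp32 copy, and offloading the sharded parameters to CPU via the CPUOffload import that the repro above never actually uses.

    # Sketch (untested): bf16 load plus optional CPU offload of sharded params.
    module = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,  # ~13 GiB of weights instead of ~26 GiB in fp32
        # no device_map: let FSDP move the shards to the GPU via device_id
    )
    module.gradient_checkpointing_enable()
    model = FSDP(
        module,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=get_fsdp_wrap_policy(module),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        cpu_offload=CPUOffload(offload_params=True),  # trades throughput for memory
        device_id=torch.cuda.current_device(),
    )

Would either of these be the right direction, or is something else wrong in my setup?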