FSDP OOM when forwarding a 7B model on a 16k-token context

I cannot prevent the model from OOMing under FSDP. I know that the 7B model should not OOM on a 16k-token forward pass and only OOMs at 32k. How can I prevent the OOM in the script below?
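To explain why I expect 16k to fit, this is the back-of-the-envelope math I did (I am assuming Qwen2.5-7B's ~152k vocabulary, hidden size 3584, and fp32 logits; treat these as rough assumptions rather than measurements):

# Rough activation sizes at 16k vs. 32k tokens (batch size 1).
vocab_size, hidden_size = 152064, 3584
for seq_len in (16384, 32768):
    logits_gib = seq_len * vocab_size * 4 / 2**30   # logits tensor if kept in fp32
    hidden_gib = seq_len * hidden_size * 2 / 2**30  # one bf16 hidden-states tensor
    print(f"seq_len={seq_len}: logits ~{logits_gib:.1f} GiB, per-layer hidden ~{hidden_gib:.2f} GiB")

The full repro script: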

import torch
import torch.distributed as dist
from functools import partial
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM

def get_fsdp_wrap_policy(module):
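    """Return an FSDP auto-wrap policy that wraps the module's transformer decoder layers."""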
    if hasattr(module, "get_decoder"):
        transformer_layer_cls = module.get_decoder().layers[0].__class__
        wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={transformer_layer_cls},
        )
        return wrap_policy

    if hasattr(module, "layers"):
        transformer_layer_cls = module.layers[0].__class__
        wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={transformer_layer_cls},
        )
        return wrap_policy

    raise NotImplementedError(
        "This module is not a decoder-only transformer. Please implement the `fsdp_wrap_policy` method."
    )


if __name__ == "__main__":
    length = 16384
    dist.init_process_group(backend="nccl")
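    # NOTE: assumes a single-node launch where the global rank equals the local GPU index.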
    torch.cuda.set_device(dist.get_rank())
    
    input_ids = torch.randint(0, 100, (1, length), device=torch.cuda.current_device())
    position_ids = torch.arange(length, device=torch.cuda.current_device()).unsqueeze(0)
    # Load the full fp32 checkpoint directly onto each GPU; FSDP shards it afterwards.
    module = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        attn_implementation="flash_attention_2",
        device_map="cuda",
        torch_dtype=torch.float32,
    )
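    # Enable activation (gradient) checkpointing on the HF module before wrapping it in FSDP.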
    module.gradient_checkpointing_enable()
    model = FSDP(
        module,
        process_group=dist.group.WORLD,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=get_fsdp_wrap_policy(module),
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        device_id=torch.cuda.current_device(),
        forward_prefetch=True,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        use_orig_params=False,
    )

    # The following will cause OOM
    logits = model(input_ids, position_ids=position_ids).logits
    print(logits.shape, logits.dtype)

    dist.destroy_process_group()
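
For reference, I launch the script with torchrun; the filename and GPU count below are just placeholders for my actual setup:

torchrun --nproc_per_node=8 fsdp_forward_oom.py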