QLoRA + FSDP2 training

Hi guys, I am trying to train a model with QLoRA + FSDP2 on Kaggle's 2x T4 GPUs (the Unsloth assignment, problem B), but I keep running into this error:
RuntimeError: only Tensors of floating point dtype can require gradients

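My understanding is that the error happens because the model is loaded with a BitsAndBytesConfig (4-bit quantization). This stores the base weights as packed integer tensors (uint8) instead of floating point ones, and integer tensors cannot require gradients.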
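
The same error can be reproduced without FSDP or bitsandbytes. Wrapping an integer tensor in `nn.Parameter` with `requires_grad=True` (the default) fails, while the same tensor is accepted with `requires_grad=False`. This is a minimal sketch, assuming only that the packed 4-bit weights are `torch.uint8` (the bitsandbytes default storage dtype). It is not a trace of what FSDP2 does internally when it re-wraps parameters. The helper name `try_make_param` is hypothetical.

```python
import torch
import torch.nn as nn


def try_make_param(t: torch.Tensor, requires_grad: bool) -> bool:
    """Return True if `t` can be wrapped as an nn.Parameter with this requires_grad."""
    try:
        nn.Parameter(t, requires_grad=requires_grad)
        return True
    except RuntimeError as err:
        print(f"{t.dtype}, requires_grad={requires_grad}: {err}")
        return False


# Stand-in for packed 4-bit weights (assumption: uint8 storage, the bitsandbytes default)
packed = torch.zeros(16, dtype=torch.uint8)
dense = torch.zeros(16, dtype=torch.bfloat16)

print(try_make_param(dense, requires_grad=True))    # floating point: allowed
print(try_make_param(packed, requires_grad=False))  # frozen integer tensor: allowed
print(try_make_param(packed, requires_grad=True))   # integer + requires_grad: RuntimeError
```

The last call prints the RuntimeError message, which is the same kind of error as in the traceback above. The exact wording varies slightly between PyTorch versions.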

The error is raised when I shard the model with this code:

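```python
import os

import torch
import torch.distributed as dist
from peft import LoraConfig, TaskType, get_peft_model
# PyTorch >= 2.6; older releases expose these under torch.distributed._composable.fsdp
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from transformers import AutoModelForCausalLM

# Distributed setup (assumed: script launched with torchrun, which sets LOCAL_RANK)
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")

# Pre-quantized bitsandbytes 4-bit checkpoint
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": local_rank},  # each rank loads the full model onto its own GPU
    attn_implementation="sdpa",
)
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Get LoRA and set up the model
model = get_peft_model(model, lora_config)

model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
model.enable_input_require_grads()
base_model = model.get_base_model()

# FSDP2: shard each decoder layer, then the root model
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    output_dtype=torch.bfloat16,
)

for layer in base_model.model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```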
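
To see which parameters are stored in a non-floating-point dtype before calling `fully_shard`, a small check over `named_parameters()` can help. This is a minimal sketch. The helper `non_float_params` and the `ToyQuantLinear` module are hypothetical and only imitate a frozen integer base weight next to a trainable float adapter.

```python
import torch
import torch.nn as nn


def non_float_params(module: nn.Module) -> list[tuple[str, torch.dtype]]:
    """List (name, dtype) for every parameter that is not floating point."""
    return [
        (name, p.dtype)
        for name, p in module.named_parameters()
        if not p.is_floating_point()
    ]


class ToyQuantLinear(nn.Module):
    """Hypothetical stand-in for a quantized linear layer with a LoRA adapter."""

    def __init__(self) -> None:
        super().__init__()
        # Packed integer weight, frozen like a QLoRA base weight
        self.weight = nn.Parameter(torch.zeros(8, dtype=torch.uint8), requires_grad=False)
        # Trainable floating point parameter, like a LoRA adapter
        self.lora_A = nn.Parameter(torch.zeros(4, 2))


toy = nn.Sequential(ToyQuantLinear(), nn.Linear(2, 2))
for name, dtype in non_float_params(toy):
    print(name, dtype)  # prints: 0.weight torch.uint8
```

On the real model, something like `non_float_params(base_model.model.layers[0])` should list the quantized projection weights. This is a suggested diagnostic, not part of the original setup.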

I keep running into the above error. When I asked ChatGPT and Claude, they suggested dequantizing the model to bf16, but then I get an out-of-memory (OOM) error.
Is there any workaround?
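
A back-of-the-envelope estimate shows why the bf16 route runs out of memory. This assumes roughly 8.03B parameters for Llama 3.1 8B and counts weights only, ignoring activations, LoRA adapters, optimizer state and quantization metadata. With `device_map={"": local_rank}`, each rank loads a full copy of the model before `fully_shard` runs. That means the per-copy figure, not the sharded one, is what must fit on a single GPU at load time.

```python
# Rough weight-memory estimate (weights only; everything else is ignored)
NUM_PARAMS = 8.03e9  # approximate parameter count of Llama 3.1 8B

BYTES_PER_PARAM = {
    "bf16": 2.0,   # 16 bits per weight
    "4-bit": 0.5,  # 4 bits per weight, two packed per byte
}


def weight_gib(num_params: float, dtype: str) -> float:
    """Approximate weight memory in GiB for the given storage format."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


for dtype in BYTES_PER_PARAM:
    full = weight_gib(NUM_PARAMS, dtype)
    print(f"{dtype}: {full:.2f} GiB per full copy, {full / 2:.2f} GiB per GPU once sharded over 2")
```

By this estimate, a full bf16 copy needs about 14.96 GiB just for the weights, roughly 4x the 4-bit footprint of about 3.74 GiB.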