FSDP2 for GGUF / BNB models

I know this question sounds a bit crazy, but I need to shard a GGUF or BitsAndBytes model, or more generally any model whose weights are quantized below FP8.

The problem is that when I try to load a distributed state_dict, the model either fails to load or the forward pass breaks with mismatched GGML tensor shapes during matrix multiplication. For example, with a model like Wan 2.1 1.3B, one of its attention weights natively has shape 1536×1536, but after GGUF quantization (loaded into a non-FSDP model) the packed tensor shows up as 1536×864, since the quantized blocks are stored as raw bytes and the last dimension shrinks.
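
Just to show where that 864 probably comes from, here is a back-of-the-envelope check, assuming the layer was quantized to Q4_K (each super-block of 256 weights is packed into 144 bytes); the constants are assumptions on my side, not something I pulled from the checkpoint:

```python
# Assuming GGUF Q4_K packing: 256 weights per super-block, 144 bytes per packed block.
QK_K = 256              # values per Q4_K super-block
Q4_K_BLOCK_BYTES = 144  # bytes per packed Q4_K super-block

row_elements = 1536
packed_row_bytes = row_elements // QK_K * Q4_K_BLOCK_BYTES
print(packed_row_bytes)  # 864 -> matches the (1536, 864) uint8 tensor I see after loading
```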

This works fine for single-GPU inference, but not with FSDP2 when using a distributed state_dict: I often get errors like “cannot copy meta tensor”, apparently because of shape mismatches between the meta-initialized model parameters and the quantized weights.
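
A minimal sketch of what I'm attempting (the toy `nn.Linear` stands in for the Wan attention projection, and the fake uint8 state_dict stands in for the GGUF weights; both are assumptions to keep the repro small):

```python
# Run with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import set_model_state_dict, StateDictOptions

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Model is built on the meta device with the dequantized shape (1536, 1536).
with torch.device("meta"):
    model = nn.Linear(1536, 1536, bias=False)

fully_shard(model)  # FSDP2 shards the (1536, 1536) parameter into DTensors

# The state_dict coming from the GGUF file holds packed bytes instead:
# shape (1536, 864), dtype uint8, so neither shape nor dtype matches the
# sharded meta parameter.
gguf_like_sd = {"weight": torch.zeros(1536, 864, dtype=torch.uint8)}

# This is where it blows up with a shape / "cannot copy meta tensor" style error.
set_model_state_dict(
    model,
    gguf_like_sd,
    options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
```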

Has anyone tried something similar? Or is this just a waste of time to pursue?

For reference:

  • PyTorch 2.8, CUDA 12.8, Python 3.11, running on WSL with an RTX 3090.

Thanks for your time.