Hey all,
I’m hitting a weird precision issue trying to cross-check my FSDP2 run against a DDP run.
I’m training a decoder model using flex_attention and torch.compile. When I compare gradients between the FSDP2 and DDP runs after one step (same weights, same inputs), the token-embedding gradients are significantly off (max absolute diff around ~18.0), while the rest of the layers match reasonably well (~1.0 diff).
My setup:

- FSDP2: using `MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)`.
- DDP: standard setup with `autocast(dtype=torch.bfloat16)`.
- Code: in DDP, I’m manually doing `x = self.embedding(input).to(torch.bfloat16)` before passing into the transformer.
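For concreteness, here’s a single-process sketch of my DDP-side embedding path (model details omitted; the vocab/hidden sizes are made up, and I’m using CPU autocast here just so the snippet is self-contained):

```python
import torch
import torch.nn as nn

# Simplified stand-in for my DDP forward path: fp32 master weights,
# autocast region, manual cast of the embedding output to bf16.
emb = nn.Embedding(32000, 512)  # fp32 master weights, as in DDP
input_ids = torch.randint(0, 32000, (2, 16))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # The embedding lookup isn't on autocast's cast list, so on my setup
    # the output comes out in the weight's dtype (fp32); hence the
    # manual cast before the transformer blocks:
    x = emb(input_ids).to(torch.bfloat16)
```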
What I think is happening: since FSDP2 has `param_dtype=bf16`, it materializes the weights in BF16 for the forward pass, which forces the backward pass (the sparse index-add accumulation into the embedding gradient) to happen in BF16. In DDP, the source weights are still FP32, so autograd casts the incoming gradient back to FP32 before accumulating, and I end up with “too much” precision compared to FSDP.
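This hypothesis can be checked without any distributed setup at all (sizes made up; the repeated indices over a small vocab are what force the sparse accumulation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ids = torch.randint(0, 100, (4, 512))  # many repeated ids -> accumulation

# DDP-like path: fp32 master weight, output cast to bf16. The .to()
# backward casts the incoming bf16 grad back to fp32 BEFORE the
# embedding backward index-adds, so accumulation happens in fp32.
emb32 = nn.Embedding(100, 64)
emb32(ids).to(torch.bfloat16).sum().backward()
print(emb32.weight.grad.dtype)  # torch.float32

# FSDP2-like path: the materialized weight itself is bf16, so the
# sparse accumulation in backward happens entirely in bf16.
emb16 = nn.Embedding(100, 64).to(torch.bfloat16)
emb16(ids).sum().backward()
print(emb16.weight.grad.dtype)  # torch.bfloat16
```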
The question: is there a clean/standard way to tell DDP (or autocast) to force BF16 accumulation for the embedding layer so it matches FSDP’s behavior? Or is there a method of doing this that I’m missing? Or is this just an inherent artifact of DDP vs. FSDP, and if so, how large a difference should I be expecting?
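One candidate workaround I’ve been considering (my own idea, not an official DDP/autocast knob, so I’d appreciate a sanity check) is to do the lookup against a bf16 copy of the fp32 weight, so the index-add in backward runs in bf16 like FSDP2, while the gradient that lands on the fp32 leaf, and that DDP all-reduces, is still fp32:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(1000, 64)           # fp32 master weight, as in DDP
ids = torch.randint(0, 1000, (4, 128))

# The embedding backward computes the grad w.r.t. the *bf16* tensor
# (bf16 index-add, matching FSDP2's param_dtype=bf16); the .to() node
# then casts the finished grad back to fp32 once, at the leaf.
x = F.embedding(ids, emb.weight.to(torch.bfloat16))
x.sum().backward()
print(emb.weight.grad.dtype)  # torch.float32
```

Does this look sane, or does it break something DDP relies on?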
Thanks!