FSDP2 backward issue

Hi experts,

I want to train models with FSDP2, but I found a mismatch in the updated parameters compared to the non-FSDP model. Here is an example:

import copy

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import fully_shard

torch.manual_seed(1)
model = nn.Sequential(
    nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)),
    nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)),
    nn.Linear(128, 128),
).cuda()
inputs = torch.randn(2, 2, 128).cuda()

# fsdp copy
fsdp_model = copy.deepcopy(model)

out_ref = model(inputs)
optimizer_ref = torch.optim.SGD(model.parameters(), lr=0.1)

# apply FSDP
fully_shard(fsdp_model[0])
fully_shard(fsdp_model[1])
fully_shard(fsdp_model)

out_test = fsdp_model(inputs)
optimizer_test = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)

# This passes
assert torch.allclose(out_ref, out_test)

out_ref.sum().backward()
out_test.sum().backward()

# This passes; all grads are identical
for (n1, p1), (n2, p2) in zip(model.named_parameters(), fsdp_model.named_parameters()):
    assert torch.allclose(p1, p2.full_tensor())
    assert torch.allclose(p1.grad, p2.grad.full_tensor())

optimizer_ref.step()
optimizer_ref.zero_grad()
optimizer_test.step()
optimizer_test.zero_grad()


# This fails: p1 != p2 for the last Linear layer in the model
for (n1, p1), (n2, p2) in zip(model.named_parameters(), fsdp_model.named_parameters()):
    assert torch.allclose(p1, p2.full_tensor())

out_ref_1 = model(inputs)
out_test_1 = fsdp_model(inputs)
# This fails because the last layer is different
assert torch.allclose(out_ref_1, out_test_1)

torch.multiprocessing is used to run the code. I expect the weights of the FSDP model to be identical to the non-FSDP model's after the optimizer step, but the last linear layer is different, which I can't understand since the gradients are identical.
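
(For completeness, the launcher around the snippet above is the standard torch.multiprocessing spawn pattern; the worker / run_repro names and the 2-process world size below are only illustrative, not my exact script.)

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # one process per GPU; NCCL needs a rendezvous address
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    run_repro()  # hypothetical helper wrapping the snippet above
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)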

If I also fully_shard the last linear layer, the test passes, but in LLMs there is a linear LM head outside the decoder layers, and from what I can tell in torchtitan, calling fully_shard on every submodule is not the expected usage.
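
(What I mean by the torchtitan-style wrapping is roughly the sketch below, with illustrative names like model.layers and lm_head: fully_shard each block, then one call on the root, so the LM head and embeddings land in the root module's parameter group rather than getting their own call.)

# Rough torchtitan-style wrapping (illustrative names, not the actual torchtitan code)
for block in model.layers:   # shard each decoder block individually
    fully_shard(block)
fully_shard(model)           # root call manages the remaining params (embeddings, lm_head)
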
Any help is appreciated, thank you!

Thanks for sharing the repro, it's very helpful!

Is there a reason to init the optimizer after the 1st forward? Usually we init the optimizer before the 1st forward.

I modified the repro to init the optimizer before the 1st forward, and the numerics became the same:

# torchrun --standalone --nproc_per_node=2 test_fsdp2_numerics.py

import copy
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import fully_shard

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    torch.manual_seed(1)
    dim = 4
    model = nn.Sequential(
        nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)),
        nn.Sequential(nn.Linear(dim, dim), nn.Linear(dim, dim)),
        nn.Linear(dim, dim),
    ).cuda()
    inputs = torch.randn(2, 2, dim).cuda()

    # fsdp copy
    fsdp_model = copy.deepcopy(model)
    optimizer_ref = torch.optim.SGD(model.parameters(), lr=0.1)
    out_ref = model(inputs)

    # apply FSDP
    fully_shard(fsdp_model[0])
    fully_shard(fsdp_model[1])
    fully_shard(fsdp_model)
    optimizer_test = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
    out_test = fsdp_model(inputs)

    # This passes
    assert torch.allclose(out_ref, out_test)

    out_ref.sum().backward()
    out_test.sum().backward()

    # This passes; all grads are identical
    for (n1, p1), (n2, p2) in zip(model.named_parameters(), fsdp_model.named_parameters()):
        assert torch.equal(p1, p2.full_tensor())
        assert torch.equal(p1.grad, p2.grad.full_tensor())

    optimizer_ref.step()
    # optimizer_ref.zero_grad()
    optimizer_test.step()
    # optimizer_test.zero_grad()


    # With the optimizer created before the 1st forward, this now passes as well
    for (n1, p1), (n2, p2) in zip(model.named_parameters(), fsdp_model.named_parameters()):
        assert torch.equal(p1, p2.full_tensor()), f"{p1=} {p2.full_tensor()=}"

    out_ref_1 = model(inputs)
    out_test_1 = fsdp_model(inputs)
    # This now passes too: the last layer matches
    assert torch.equal(out_ref_1, out_test_1)


if __name__ == "__main__":
    main()