FSDP2 issue with layer sharding

I am trying to get FSDP2 working for a sequence classification model from transformers. Here are two ways of doing it:

Case 1 (only the inner LlamaModel, its embedding, and the decoder layers are wrapped):

Model(
  (model): LlamaForSequenceClassification(
    (model): FSDPLlamaModel(
      (embed_tokens): FSDPEmbedding(32000, 1536, padding_idx=0)
      (layers): ModuleList(
        (0-15): 16 x FSDPLlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (k_proj): Linear(in_features=1536, out_features=768, bias=False)
            (v_proj): Linear(in_features=1536, out_features=768, bias=False)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (up_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (down_proj): Linear(in_features=4096, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((1536,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (score): Linear(in_features=1536, out_features=1, bias=False)
  )
)

Case 2 (LlamaForSequenceClassification itself is also wrapped with fully_shard):

Model(
  (backbone): FSDPLlamaForSequenceClassification(
    (model): FSDPLlamaModel(
      (embed_tokens): FSDPEmbedding(32000, 1536, padding_idx=0)
      (layers): ModuleList(
        (0-15): 16 x FSDPLlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (k_proj): Linear(in_features=1536, out_features=768, bias=False)
            (v_proj): Linear(in_features=1536, out_features=768, bias=False)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (up_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (down_proj): Linear(in_features=4096, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((1536,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (score): Linear(in_features=1536, out_features=1, bias=False)
  )
)

The main difference is whether LlamaForSequenceClassification itself is also wrapped with fully_shard. In the first case, training seems to work, although I am not sure every layer is actually being updated. In the second case, the model does not train at all.
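
For reference, a minimal sketch of the sharding calls that would produce these two structures, assuming the FSDP2 fully_shard API (importable from torch.distributed.fsdp in recent PyTorch releases, torch.distributed._composable.fsdp in older ones) and a hypothetical load_backbone() standing in for however the classification model is built:

```python
from torch.distributed.fsdp import fully_shard  # older releases: from torch.distributed._composable.fsdp import fully_shard

backbone = load_backbone()  # hypothetical: builds the LlamaForSequenceClassification model

# Shard each decoder layer and the token embedding as their own parameter
# groups, then the inner LlamaModel. This matches the first printout, where
# the outer LlamaForSequenceClassification (and the score head) stays unwrapped.
for layer in backbone.model.layers:
    fully_shard(layer)
fully_shard(backbone.model.embed_tokens)
fully_shard(backbone.model)

# The second printout additionally wraps the classification model itself,
# which pulls the score head into a sharded parameter group as well:
# fully_shard(backbone)
```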

Is this expected? I am trying to get a better understanding of which layers should be sharded and which should be left unsharded.
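
As a side note, one way I sanity-check whether every layer actually receives updates is to look at the per-parameter gradients after a backward pass, before zero_grad. Just an illustrative helper, nothing FSDP-specific:

```python
def report_missing_grads(model):
    # Call this after loss.backward() and before optimizer.zero_grad().
    # With FSDP2 the parameters are DTensors, but .grad is filled in the same way.
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is None:
            print(f"no gradient for {name}")
```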

I tried to roughly follow: torchtune/torchtune/training/_distributed.py at main · pytorch/torchtune · GitHub

Okay, I think it was actually a trivial mistake on my side: the optimizer was being initialized before the fully_shard calls instead of after them. Will reopen if that does not solve it.
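
For anyone hitting the same symptom, a sketch of the ordering that matters (build_model and shard_model are hypothetical placeholders for the model construction and the fully_shard calls sketched above):

```python
import torch

model = build_model()  # hypothetical model constructor

# Wrong order: the optimizer captures the original, unsharded parameters;
# fully_shard then swaps them out for sharded DTensors, so the tensors the
# optimizer updates are no longer the ones used in forward/backward.
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# shard_model(model)

# Correct order: shard first, then build the optimizer over the sharded params.
shard_model(model)  # hypothetical helper applying fully_shard as sketched above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
```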