I am trying to get FSDP2 working for a sequence classification model from transformers. Here are the two ways of wrapping it that I have tried:
Model(
  (model): LlamaForSequenceClassification(
    (model): FSDPLlamaModel(
      (embed_tokens): FSDPEmbedding(32000, 1536, padding_idx=0)
      (layers): ModuleList(
        (0-15): 16 x FSDPLlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (k_proj): Linear(in_features=1536, out_features=768, bias=False)
            (v_proj): Linear(in_features=1536, out_features=768, bias=False)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (up_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (down_proj): Linear(in_features=4096, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((1536,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (score): Linear(in_features=1536, out_features=1, bias=False)
  )
)
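For reference, the wrapping that produces this first layout looks roughly like the following. This is only a simplified sketch: I am leaving out the device mesh, reshard, and mixed-precision arguments, and the attribute names just follow the printed module tree above.

```python
# Rough sketch of variant 1: shard each decoder layer, the embedding, and the
# inner LlamaModel, but leave LlamaForSequenceClassification (which holds the
# score head) and the outer Model wrapper unsharded.
from torch.distributed.fsdp import fully_shard  # torch.distributed._composable.fsdp on older builds

def shard_variant_1(model):
    llama = model.model              # LlamaForSequenceClassification
    base = llama.model               # LlamaModel
    for layer in base.layers:        # -> FSDPLlamaDecoderLayer
        fully_shard(layer)
    fully_shard(base.embed_tokens)   # -> FSDPEmbedding
    fully_shard(base)                # -> FSDPLlamaModel
    return model
```

The second variant prints as: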
Model(
  (backbone): FSDPLlamaForSequenceClassification(
    (model): FSDPLlamaModel(
      (embed_tokens): FSDPEmbedding(32000, 1536, padding_idx=0)
      (layers): ModuleList(
        (0-15): 16 x FSDPLlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (k_proj): Linear(in_features=1536, out_features=768, bias=False)
            (v_proj): Linear(in_features=1536, out_features=768, bias=False)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (up_proj): Linear(in_features=1536, out_features=4096, bias=False)
            (down_proj): Linear(in_features=4096, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((1536,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((1536,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (score): Linear(in_features=1536, out_features=1, bias=False)
  )
)
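This second variant is the same as the first except for one extra fully_shard call on the LlamaForSequenceClassification itself (attached as backbone here). Again only a rough sketch:

```python
# Rough sketch of variant 2: same as variant 1, plus sharding the
# LlamaForSequenceClassification itself -> FSDPLlamaForSequenceClassification.
from torch.distributed.fsdp import fully_shard

def shard_variant_2(model):
    llama = model.backbone           # LlamaForSequenceClassification
    base = llama.model               # LlamaModel
    for layer in base.layers:
        fully_shard(layer)
    fully_shard(base.embed_tokens)
    fully_shard(base)
    fully_shard(llama)               # the extra call that differs from variant 1
    return model
```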
The main difference is whether LlamaForSequenceClassification itself is also wrapped with fully_shard. The first case (where it is not wrapped) seems to train fine, although I am not sure all layers are actually being trained properly. The second case (where it is wrapped) does not train at all.
Is this expected? I am trying to understand better which layers should be sharded and which should not.
I tried to roughly follow: torchtune/torchtune/training/_distributed.py at main · pytorch/torchtune · GitHub
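As for "not sure all layers are properly trained": my plan is to just look at per-parameter gradient norms after a backward pass, something like the sketch below. This is only a debugging aid; with FSDP2 the parameters are DTensors, so it inspects the local shard only.

```python
import torch

@torch.no_grad()
def report_grad_norms(model):
    # After loss.backward(), print which parameters actually received a gradient.
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            print(f"{name}: NO GRAD")
            continue
        # FSDP2 params/grads are DTensors; looking at the local shard is enough
        # to see whether a layer gets any gradient at all.
        local = grad.to_local() if hasattr(grad, "to_local") else grad
        print(f"{name}: local grad norm {local.norm().item():.3e}")
```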