Hi all
I’m trying to train T5-3b with FSDP, using the following auto-wrap policy:
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap every T5Block in its own FSDP unit.
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)
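I apply the policy roughly like this (simplified from my training script; the process group is initialized elsewhere and everything else is left at the defaults):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import T5ForConditionalGeneration

# dist.init_process_group(...) has already been called at this point.
model = T5ForConditionalGeneration.from_pretrained("t5-3b")
model = FSDP(
    model,
    auto_wrap_policy=t5_auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)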
T5-3b is then wrapped as follows:
FullyShardedDataParallel(
  (_fsdp_wrapped_module): T5ForConditionalGeneration(
    (shared): Embedding(32128, 1024)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 1024)
      (block): ModuleList(
        (0): FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                  (relative_attention_bias): Embedding(32, 32)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (1-23): 23 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (final_layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder): T5Stack(
      (embed_tokens): Embedding(32128, 1024)
      (block): ModuleList(
        (0): FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                  (relative_attention_bias): Embedding(32, 32)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerCrossAttention(
                (EncDecAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (2): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (1-23): 23 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerCrossAttention(
                (EncDecAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (2): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (final_layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (lm_head): Linear(in_features=1024, out_features=32128, bias=False)
  )
)
Each T5Block is wrapped in its own FSDP unit; however, I want to wrap multiple T5Blocks into a single unit.
Is there any way to do this with functools.partial(transformer_auto_wrap_policy, ...)?
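The only workaround I have come up with so far is to move the wrapping boundary up to T5Stack, so that the whole encoder and the whole decoder each become a single unit. That is the all-blocks extreme rather than grouping a few consecutive blocks, and I am not sure how it interacts with the shared input embedding. Rough, untested sketch (coarse_wrap_policy is just my own name for it):

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Stack

# Untested idea: wrap at the T5Stack level instead of T5Block, so the whole
# encoder and the whole decoder each become one FSDP unit. This puts all 24
# blocks of a stack into one unit, not e.g. pairs of consecutive blocks.
coarse_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Stack},
)

This is much coarser than what I actually want; ideally I would group, say, two or four consecutive T5Blocks per unit.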
Best regards
Taekyoung