How can I wrap multiple transformer blocks in T5?

Hi all,

I’m trying to train T5-3B with FSDP…

using the following auto-wrap policy:

    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.t5.modeling_t5 import T5Block

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block
        },
    )
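
The policy is then passed to the FSDP constructor roughly like this (simplified sketch; device placement and the other FSDP arguments depend on the setup):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    model = FSDP(
        model,                                  # T5ForConditionalGeneration
        auto_wrap_policy=t5_auto_wrap_policy,   # the wrap policy above
        device_id=torch.cuda.current_device(),  # shard onto the local GPU
    )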

T5-3B is wrapped as follows:

FullyShardedDataParallel(
  (_fsdp_wrapped_module): T5ForConditionalGeneration(
    (shared): Embedding(32128, 1024)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 1024)
      (block): ModuleList(
        (0): FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                  (relative_attention_bias): Embedding(32, 32)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (1-23): 23 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (final_layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (decoder): T5Stack(
      (embed_tokens): Embedding(32128, 1024)
      (block): ModuleList(
        (0): FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                  (relative_attention_bias): Embedding(32, 32)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerCrossAttention(
                (EncDecAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (2): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (1-23): 23 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerCrossAttention(
                (EncDecAttention): T5Attention(
                  (q): Linear(in_features=1024, out_features=4096, bias=False)
                  (k): Linear(in_features=1024, out_features=4096, bias=False)
                  (v): Linear(in_features=1024, out_features=4096, bias=False)
                  (o): Linear(in_features=4096, out_features=1024, bias=False)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (2): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=1024, out_features=16384, bias=False)
                  (wo): Linear(in_features=16384, out_features=1024, bias=False)
                  (dropout): Dropout(p=0.1, inplace=False)
                  (act): ReLU()
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
      )
      (final_layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (lm_head): Linear(in_features=1024, out_features=32128, bias=False)
  )
)

Each T5Block is wrapped in its own FSDP unit; however, I want to wrap multiple T5Blocks in a single unit…

Is there any way to do this with functools.partial(transformer_auto_wrap_policy)?
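
For context, as far as I can tell transformer_auto_wrap_policy only decides which module classes become individual FSDP units, so the closest I can get with it alone is coarser wrapping, e.g. at the T5Stack level (just a sketch, and T5’s tied embedding weights may get in the way here):

    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.t5.modeling_t5 import T5Stack

    # One unit per T5Stack (whole encoder / whole decoder) instead of one per T5Block
    t5_stack_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Stack
        },
    )

But that puts all 24 blocks of a stack into a single unit, which loses most of the per-block memory savings and still isn’t the “N blocks per unit” grouping I’m after.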

Best regards
Taekyoung

Answered on GitHub: Wrapping multiple layers in Pytorch FSDP · Issue #116986 · pytorch/pytorch