Sub-modules in FSDP

I am new to FSDP.

This is my understanding of how FSDP handles sub-modules.
FSDP relies on torch.nn.Module.forward() to all-gather the sharded parameters. So when I create sub-modules with torch.nn.Sequential() inside a torch.nn.Module class, they do not work.
Also, the FSDP class documentation says:

FSDP does not support running the forward pass of a submodule
that is contained in an FSDP instance. This is because the
submodule’s parameters will be sharded, but the submodule itself
is not an FSDP instance, so its forward pass will not all-gather
the full parameters appropriately.
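
To make this concrete for myself, here is a toy sketch of what I think the documentation is describing (names and sizes are made up; a single-process, single-GPU setup just for illustration):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# toy single-process, single-GPU setup just for illustration
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Sequential(nn.Linear(16, 32), nn.ReLU())  # plain sub-module
        self.head = nn.Linear(32, 4)

    def forward(self, x):
        return self.head(self.embed(x))

model = FSDP(Net().cuda())             # only the root module is an FSDP instance
x = torch.randn(2, 16, device="cuda")

out = model(x)                         # fine: the root FSDP forward all-gathers parameters
# model.module.embed(x)                # unsupported per the docs: embed's parameters are
                                       # sharded, but embed is not its own FSDP unit, so its
                                       # forward never all-gathers the full parameters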

As a concrete case, the FSDP tutorial (PyTorch) demonstrates DeepViT, CrossViT, and CaiT with FSDP.
The tutorial wraps them with transformer_auto_wrapper_policy, but this does not work because of the sub-modules.
(transformer_central/transformer_wrapping_tutorial/transformer_wrapper_tutorial.ipynb at main · lessw2020/transformer_central · GitHub)

class DeepViT(nn.Module):
    def __init__(self, *, image_size, ...):
        super().__init__()
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )  # <----- I think this does not work because of the sub-module.

In this case, we need to rewrite the sub-module (self.to_patch_embedding) as its own torch.nn.Module, like the following. (Assume the following is added to deep_vit.)

import torch
from torch import nn, Tensor
from einops import repeat
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x
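
A quick shape check for the rewritten sub-module (my own sanity check, not part of the tutorial; it uses the default sizes above):

# 224 / 16 = 14, so 14 * 14 = 196 patches plus 1 cls token
patch_embed = PatchEmbedding()
out = patch_embed(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 197, 768])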

class DeepViT(nn.Module):
    def __init__(self, *, image_size, ...):
        super().__init__()
        self.to_patch_embedding = PatchEmbedding()
        ...

    def forward(self, img):
        x = self.to_patch_embedding(img)
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)

Then, wrap with transformer_auto_wrap_policy.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from deep_vit import Residual, PatchEmbedding
import functools

transformer_auto_wrapper_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        PatchEmbedding,
        Residual,  # <---- your Transformer layer class
    },
)

sharded_model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrapper_policy,
    mixed_precision=mp_policy,
    sharding_strategy=model_sharding_strategy,
    device_id=torch.cuda.current_device(),  # streaming init
)
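
To double-check which sub-modules actually became their own FSDP units, I print the wrapped model on rank 0 (just a sanity check, not from the tutorial):

import torch.distributed as dist

# the repr shows a nested FullyShardedDataParallel(...) around every module
# that the auto wrap policy matched (PatchEmbedding and Residual here)
if dist.get_rank() == 0:
    print(sharded_model)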

Are the above steps correct?
If so, FSDP shards PatchEmbedding across GPUs, right? (I’m not sure whether this is the right approach or not.)

cc. FSDP1 experts @weifengpy @agu

I am not sure if I followed completely. For self.to_patch_embedding, if you wrap it with FSDP and run self.to_patch_embedding(...) (its forward), then it should work.
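
For example, something like this (a rough sketch; it wraps just that sub-module manually before wrapping the rest, where model is the un-wrapped DeepViT):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# wrap the sub-module first so it becomes its own FSDP unit with its own all-gather
model.to_patch_embedding = FSDP(model.to_patch_embedding)
# then wrap the rest; the outer FSDP instance skips parameters already managed by the inner one
sharded_model = FSDP(model)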

I created a new account as I lost the previous password.

When I wrap the model with the transformer wrap policy following the tutorial (i.e., with self.to_patch_embedding left as an nn.Sequential), I get this error:

RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view
is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace.
You should replace the inplace operation by an out-of-place one.

The code below comes from the linked tutorial (transformer_central/transformer_wrapping_tutorial/transformer_wrapper_tutorial.ipynb at main · lessw2020/transformer_central · GitHub).

class DeepViT(nn.Module):
    def __init__(self, *, image_size, ...):
        super().__init__()
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

Then,

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from deep_vit import Residual
import functools

transformer_auto_wrapper_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        Residual,  # <---- your Transformer layer class
    },
)

sharded_model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrapper_policy,
    mixed_precision=mp_policy,
    sharding_strategy=model_sharding_strategy,
    device_id=torch.cuda.current_device(),  # streaming init
)

So, I think somehow FSDP cannot wrap sub-modules :thinking:

I am still not following exactly. How did you update the auto_wrap_policy? Did you include nn.Sequential as one of the classes?

I ran into the same error. I was able to resolve it by defining a new, empty subclass of nn.Sequential (e.g. MySequential), instantiating the sub-module with that new subclass, and adding an isinstance(module, MySequential) check to my existing custom auto wrap policy so that the new class gets wrapped (I did not use the transformer_auto_wrapper_policy).
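
Roughly, the workaround looks like this (a sketch with made-up names; I use lambda_auto_wrap_policy here as a stand-in for my custom policy):

import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

class MySequential(nn.Sequential):
    """Empty subclass so the wrap policy can target this sub-module by type."""
    pass

# in the model definition, build the sub-module with the new class:
#   self.to_patch_embedding = MySequential(Rearrange(...), nn.Linear(patch_dim, dim))

my_policy = functools.partial(
    lambda_auto_wrap_policy,
    # also list your existing transformer block class(es) in the isinstance check
    lambda_fn=lambda m: isinstance(m, MySequential),
)

sharded_model = FSDP(model, auto_wrap_policy=my_policy)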

I feel the PyTorch FSDP documentation is misleading when it says: “FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance.” Apparently, this sentence is meant to say: “FSDP does not support running the forward pass of a submodule that is contained in an FSDP instance, unless that submodule is itself wrapped in its own FSDP instance.”