I am new to FSDP.
This is my understanding of how FSDP handles sub-modules.
FSDP relies on torch.nn.Module.forward() to all-gather the sharded parameters. So, when I create sub-modules with torch.nn.Sequential() inside a torch.nn.Module class, it does not work.
Also, the FSDP class documentation says:
FSDP does not support running the forward pass of a submodule
that is contained in an FSDP instance. This is because the
submodule’s parameters will be sharded, but the submodule itself
is not an FSDP instance, so its forward pass will not all-gather
the full parameters appropriately.
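Here is a minimal toy example of how I understand this restriction (my own sketch, not from the docs; TinyNet and its sizes are made up for illustration, and it assumes torch.distributed is already initialized, e.g. via torchrun):

# Minimal sketch (my own toy example, not from the FSDP docs) of the restriction
# quoted above. Assumes a process group is already initialized and a GPU is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)

model = FSDP(TinyNet().cuda())        # only the root module is an FSDP instance
x = torch.randn(4, 8, device="cuda")

out = model(x)        # OK: the root FSDP forward all-gathers the full parameters
# out = model.net(x)  # not supported: `net` is sharded, but it is not an FSDP
#                     # instance, so its forward will not all-gather the parameters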
For example, the FSDP tutorial (PyTorch) demonstrates DeepViT, CrossViT, and CaiT with FSDP.
The tutorial wraps them with transformer_auto_wrapper_policy, but it does not work because of the sub-modules.
(transformer_central/transformer_wrapping_tutorial/transformer_wrapper_tutorial.ipynb at main · lessw2020/transformer_central · GitHub)
class DeepViT(nn.Module):
    def __init__(self, *, image_size, ...):
        super().__init__()
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )   # <----- I think this does not work because of the sub-module.
In this case, we need to rewrite the sub-module (self.to_patch_embedding) as its own torch.nn.Module subclass, like the following. (Assume the following lines are added to deep_vit.)
import torch
from torch import nn, Tensor
from einops import repeat
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x
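As a quick sanity check (my own addition, not in the tutorial), the rewritten module should produce the expected (batch, num_patches + 1, emb_size) shape:

# Quick shape check for the rewritten PatchEmbedding (my own addition):
# 224 / 16 = 14, so 14 * 14 = 196 patches plus the cls token -> 197 tokens.
import torch

emb = PatchEmbedding()                  # defaults: patch_size=16, emb_size=768, img_size=224
out = emb(torch.randn(2, 3, 224, 224))
print(out.shape)                        # torch.Size([2, 197, 768])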
class DeepViT(nn.Module):
    def __init__(self, *, image_size, ...):
        super().__init__()
        self.to_patch_embedding = PatchEmbedding()
        ...

    def forward(self, img):
        x = self.to_patch_embedding(img)
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
Then, wrap the model with transformer_auto_wrap_policy:
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from deep_vit import Residual, PatchEmbedding

transformer_auto_wrapper_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={
        PatchEmbedding,
        Residual,  # <---- your Transformer layer class
    },
)

sharded_model = FSDP(
    model,
    auto_wrap_policy=transformer_auto_wrapper_policy,
    mixed_precision=mp_policy,
    sharding_strategy=model_sharding_strategy,
    device_id=torch.cuda.current_device(),  # streaming init
)
Are the above steps correct?
If so, FSDP shards PatchEmbedding across GPUs, right? (I'm not sure whether this is the right approach or not.)
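One way I plan to check this (my own idea, not from the tutorial) is to print the wrapped model and the local parameter count on each rank:

# Sketch of how I would verify the wrapping/sharding (my own idea).
# Each PatchEmbedding / Residual should appear inside its own
# FullyShardedDataParallel wrapper, and each rank should only hold
# a shard of the flattened parameters.
import torch.distributed as dist

if dist.get_rank() == 0:
    print(sharded_model)  # shows the nested FullyShardedDataParallel units

local_numel = sum(p.numel() for p in sharded_model.parameters())
print(f"rank {dist.get_rank()}: local parameter numel = {local_numel}")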