Hi all!
I am trying to wrap a model with a Transformer-like architecture in FSDP.
At the moment I use transformer_auto_wrap_policy with Block as the module class to wrap.
However, I would also like to wrap the embedding and lm_head layers. Below is a pseudo-code example of what I am currently doing:
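import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, n_blocks):
        super().__init__()
        # Constructor arguments (vocabulary size, hidden size) omitted in this pseudo-code
        self.token_embedding = nn.Embedding()
        self.blocks: nn.ModuleList[Block] = create_blocks(n_blocks)
        self.lm_head = nn.Linear()
        # tie weights: lm_head reuses the embedding matrix
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x)
        return self.lm_head(x)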
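import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrapping model in FSDP (assumes the process group is already initialized)
block_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)
model = MyModel(2)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=block_auto_wrap_policy,
    use_orig_params=True,
    device_id=torch.cuda.current_device(),
)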
From what I can tell, this wrap policy only considers the modules in self.blocks. What do I need to change so that the other two layers are wrapped as well?
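To make the observation concrete, here is a minimal sketch that runs without FSDP or a GPU. It assumes the policy's per-module decision boils down to an isinstance check against transformer_layer_cls, and mirrors that check in plain PyTorch. `TinyModel`, the toy `Block`, `modules_selected`, and all sizes are hypothetical stand-ins for illustration and are not part of the FSDP API.

```python
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for the Block class from the post."""

    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)


class TinyModel(nn.Module):
    """Hypothetical toy version of MyModel with made-up sizes."""

    def __init__(self, n_blocks, vocab_size=16, dim=8):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([Block(dim) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # tie weights, as in the original model
        self.lm_head.weight = self.token_embedding.weight


def modules_selected(model, layer_classes):
    """Return names of submodules that an isinstance-based policy would pick."""
    classes = tuple(layer_classes)
    return [
        name
        for name, module in model.named_modules()
        if name and isinstance(module, classes)  # skip the root module ("")
    ]


model = TinyModel(n_blocks=2)
only_blocks = modules_selected(model, {Block})
with_embedding = modules_selected(model, {Block, nn.Embedding})
with_linear = modules_selected(model, {nn.Linear})

print(only_blocks)
print(with_embedding)
print(with_linear)
```

With only `Block` in the set, neither token_embedding nor lm_head is selected, which matches what I see. Adding `nn.Embedding` picks up the embedding, while adding `nn.Linear` would also match every linear layer inside the blocks, so a class-based set looks too coarse for lm_head. I am also unsure whether the tied weight allows token_embedding and lm_head to end up in separate FSDP units at all; my understanding is that shared parameters may need to live in the same unit, but I have not verified this.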
Thank you very much in advance!