How to wrap layers that have shared weights in FSDP

Hi all!

I am trying to wrap a model with a Transformer-like architecture in FSDP.
Currently I am using transformer_auto_wrap_policy with Block as the module class to wrap.

However, I would also like to wrap the embedding and lm_head layers, whose weights are tied. Below is a pseudo-code example of what I am currently doing:

import torch
import torch.nn as nn


class MyModel(nn.Module):

    def __init__(self, n_blocks, vocab_size=32000, d_model=768):  # placeholder sizes
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.blocks = create_blocks(n_blocks)  # nn.ModuleList of Block modules
        self.lm_head = nn.Linear(d_model, vocab_size)

        # tie weights: lm_head shares its weight with token_embedding
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, x):

        x = self.token_embedding(x)

        for block in self.blocks:
            x = block(x)

        return self.lm_head(x)


# Wrapping the model in FSDP

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

block_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)

model = MyModel(2)
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=block_auto_wrap_policy,
    use_orig_params=True,
    device_id=torch.cuda.current_device(),
)
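
For reference, I am checking which submodules end up as their own FSDP units roughly like this (just walking named_modules and printing the nested FSDP instances, so the check itself may be naive):

# Print the submodules that were wrapped into their own FSDP units.
for name, module in model.named_modules():
    if isinstance(module, FSDP) and name != "":
        print(name)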

From what I can tell, this wrap policy only wraps the modules in self.blocks, while token_embedding and lm_head do not become their own FSDP units. What do I need to change to also wrap those two layers, given that their weights are tied?
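
One direction I was considering is combining policies, roughly as in the sketch below. Note that lambda_auto_wrap_policy and _or_policy are helpers I found in torch.distributed.fsdp.wrap (I am not sure _or_policy counts as public API), and I am not sure this plays well with the tied weights, since token_embedding and lm_head would end up in different FSDP units:

from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy, _or_policy

def wrap_embedding(module):
    # Treat the token embedding as its own FSDP unit.
    return isinstance(module, nn.Embedding)

combined_policy = functools.partial(
    _or_policy,
    policies=[
        functools.partial(lambda_auto_wrap_policy, lambda_fn=wrap_embedding),
        functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block}),
    ],
)

The idea would then be to pass combined_policy as auto_wrap_policy instead of block_auto_wrap_policy. Is something along these lines the intended approach, or is there a recommended way to handle layers with shared weights in FSDP?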

Thank you very much in advance!