Hi, I wanted to confirm my understanding of the current behavior of PyTorch FSDP (native) when it comes to setting separate optimizer hyperparameters (e.g., weight decay) for different layers/modules within an FSDP-wrapped module.
As a concrete example, I am following this tutorial for running a HuggingFace T5 model with FSDP, and I use `transformer_auto_wrap_policy` with `T5Block` as the `transformer_layer_cls`, as suggested there. If I understand correctly, in this case it is not possible to set the weight decay for the parameters of the `LayerNorm` layers inside each `T5Block` to a value different from the weight decay passed to the optimizer?
to a value different from the weight-decay that’s passed to the optimizer? In particular, I’d like to have the weight-decay value to be 0.0 for these parameters following standard practice. Without FSDP, I can filter out such parameters by name or type and send two separate parameter groups to the Optimizer. With FSDP, each T5Block
is presented as a flat parameter which prevents us from doing so. While using FSDP, I can potentially wrap each individual LayerNorm to be its own FSDP unit but it will require modifying the T5Block
module and also will create a much higher number of FSDP units, thereby reducing compute/communication efficiency.
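This is roughly what I do without FSDP (a rough sketch; `build_param_groups` is just my own helper, and the name-based filter is illustrative, matching the `layer_norm` naming used in the HuggingFace T5 implementation):

```python
import torch

def build_param_groups(model, weight_decay=0.01):
    # Put LayerNorm weights (and any biases) into a group with zero weight decay.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith("bias") or "layer_norm" in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)
```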
I have looked into the FSDP method `summon_full_params`, but I am not sure whether I can build such parameter groups from the unflattened parameters it exposes and pass them to the `torch.optim.AdamW` optimizer. If not, is there any other way to achieve this behavior with FSDP at the moment?