Hi, I wanted to confirm my understanding regarding the current behavior of PyTorch FSDP (native) when it comes to setting separate optimization parameters for different layers/modules within an FSDP wrapped module.
As a concrete example, I am following this tutorial on running a HuggingFace T5 model with FSDP, using T5Block as the transformer_layer_cls as suggested here. In this case, if I understand correctly, it is not possible to set the weight decay for the parameters of the LayerNorm layers inside each T5Block to a value different from the weight decay passed to the optimizer? In particular, I'd like the weight decay to be 0.0 for these parameters, following standard practice. Without FSDP, I can filter out such parameters by name or type and pass two separate parameter groups to the optimizer. With FSDP, each T5Block is presented as a single flat parameter, which prevents this. I could potentially wrap each individual LayerNorm as its own FSDP unit, but that would require modifying the T5Block module and would also create a much larger number of FSDP units, reducing compute/communication efficiency.
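For reference, the non-FSDP split I have in mind looks roughly like this (a sketch with a toy model standing in for T5; the 1-D check is a common heuristic that catches LayerNorm weights/biases as well as Linear biases):

```python
import torch
from torch import nn

# Toy model standing in for the T5 module tree.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.LayerNorm(16),
    nn.Linear(16, 4),
)

# Partition parameters: 1-D tensors (LayerNorm weights/biases, Linear
# biases) get no weight decay; everything else gets the regular value.
decay, no_decay = [], []
for name, param in model.named_parameters():
    if param.ndim == 1:
        no_decay.append(param)
    else:
        decay.append(param)

# Two parameter groups with different weight-decay values.
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```

With FSDP's flat parameters, named_parameters no longer yields the original per-layer tensors, so this kind of filtering breaks down.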
I have looked into the FSDP method summon_full_params, but I am not sure whether I can build such parameter groups from the unflattened parameters it exposes and pass them to the torch.optim.AdamW optimizer. If not, is there any other way to achieve this behavior with FSDP at the moment?