Setting different weight-decay values for parameters within one FSDP unit

Hi, I wanted to confirm my understanding of the current behavior of native PyTorch FSDP when it comes to setting separate optimizer hyperparameters for different layers/modules within an FSDP-wrapped module.

As a concrete example, I am following this tutorial on running a HuggingFace T5 model with FSDP, and I use the transformer_auto_wrap_policy with T5Block as the transformer_layer_cls, as suggested here. In this case, if I understand correctly, it is not possible to set the weight decay for the parameters of the LayerNorm layers inside each T5Block to a value different from the weight decay passed to the optimizer. In particular, I'd like the weight decay to be 0.0 for these parameters, following standard practice. Without FSDP, I can filter out such parameters by name or type and pass two separate parameter groups to the optimizer. With FSDP, each T5Block is presented as a single flat parameter, which prevents this kind of filtering. While using FSDP, I could potentially wrap each individual LayerNorm as its own FSDP unit, but that would require modifying the T5Block module and would also create a much larger number of FSDP units, reducing compute/communication efficiency.
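For reference, the non-FSDP parameter-group setup described above might look like the following sketch. The model here is a small stand-in, not the actual T5; the filtering logic (by name for biases, by module type for LayerNorm) is the part that matters:

```python
import torch
import torch.nn as nn

# Stand-in model: two Linear layers with a LayerNorm in between.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.LayerNorm(16),
    nn.Linear(16, 4),
)

decay, no_decay = [], []
for name, param in model.named_parameters():
    # Biases are filtered by name; LayerNorm parameters by module type.
    parent = model.get_submodule(name.rsplit(".", 1)[0])
    if "bias" in name or isinstance(parent, nn.LayerNorm):
        no_decay.append(param)
    else:
        decay.append(param)

# Two parameter groups: normal weight decay for weights,
# zero weight decay for biases and LayerNorm parameters.
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```

With a flat-parameter FSDP wrapping, `named_parameters()` no longer exposes the original per-layer parameters, so this filtering breaks.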

I have looked into the FSDP method summon_full_params, but I am not sure whether I can create such parameter groups from the unflattened parameters returned by this function and pass them to the torch.optim.AdamW optimizer. If not, is there any other way to achieve this behavior with FSDP at the moment?


You should be able to enable this behavior with use_orig_params=True, which is required to be set when multiple parameter groups are used.

Great, thanks a lot. It works with PyTorch nightly.