FSDP with non-uniform 'requires_grad'

Hello, I’m currently trying to wrap my model, which contains some frozen parameters, with nested FSDP. I read this warning in the torch.distributed.fsdp documentation:

FSDP has some constraints on freezing parameters (i.e. setting param.requires_grad=False). For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen, but we recommend not doing so since then the gradient memory usage will be higher than expected (namely, equivalent to not freezing those parameters). This means that ideally, frozen parameters should be isolated into their own nn.Modules and wrapped separately with FSDP.
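
To make sure I’m reading that correctly, here is a minimal sketch (toy module names, native torch.distributed.fsdp, assuming the process group is already initialized) of the pattern I understand the docs to recommend:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group() has already been called,
# e.g. via torchrun.

class FrozenBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    def forward(self, x):
        return self.proj(x)

class TrainableBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)

    def forward(self, x):
        return self.proj(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = FrozenBlock()
        self.trainable = TrainableBlock()
        # Freeze everything inside the isolated sub-module
        for p in self.frozen.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.trainable(self.frozen(x))

model = Model()
# Each sub-module is uniformly frozen or uniformly trainable, so each gets
# its own FSDP instance; the root FSDP then manages no mixed parameters.
model.frozen = FSDP(model.frozen)
model.trainable = FSDP(model.trainable)
model = FSDP(model)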

I didn’t want to set use_orig_params=True, so, as the docs suggest, I isolated the frozen parameters into their own nn.Modules and then wrapped every submodule with FSDP, but I’m still getting this error:

ValueError: FlatParameter requires uniform requires_grad

Could someone please clarify what they meant by ‘ideally’ or help me overcome this issue?

Thanks!

Hi, did you ever figure this out? I am trying to do something similar.

Hi, sorry for the late reply. I got around the issue by using a type-based auto-wrap policy, which lets you choose exactly which module classes get auto-wrapped by FSDP, so the modules with frozen parameters can simply be left out:

# Imports below assume the torch_xla FSDP implementation
# (torch_xla.distributed.fsdp), which is where these constructor arguments
# (compute_dtype, flatten_parameters, auto_wrapper_callable, ...) come from;
# FLAGS holds the script's command-line options.
from functools import partial

import torch
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
                                             transformer_auto_wrap_policy)

# Automatically wrap sub-modules with inner FSDP instances
auto_wrap_policy = None
auto_wrapper_callable = None
if FLAGS.auto_wrap_policy != "none":
    if FLAGS.auto_wrap_policy == "size_based":
        # Auto-wrap every sub-module with at least a minimum number of
        # parameters (default 1e6)
        auto_wrap_policy = partial(
            size_based_auto_wrap_policy,
            min_num_params=int(float(FLAGS.auto_wrap_min_num_params)))
    elif FLAGS.auto_wrap_policy == "type_based":
        # Auto-wrap only sub-modules of the listed types; modules holding
        # frozen parameters are simply not listed here
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Backbone.MBConv,  # specify the module classes you want to include here
                Backbone.FeedForward,
            })
    else:
        raise Exception(
            f"Invalid auto-wrap policy: {FLAGS.auto_wrap_policy}")
    if FLAGS.use_gradient_checkpointing:
        # Apply gradient checkpointing to auto-wrapped sub-modules if requested
        auto_wrapper_callable = lambda m, *args, **kwargs: FSDP(
            checkpoint_module(m), *args, **kwargs)


def fsdp_wrap(m):
    return FSDP(
        m,
        compute_dtype=getattr(torch, FLAGS.compute_dtype),
        fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
        flatten_parameters=FLAGS.flatten_parameters,
        pin_layout_in_collective_ops=FLAGS.pin_layout_in_collective_ops,
        auto_wrap_policy=auto_wrap_policy,
        auto_wrapper_callable=auto_wrapper_callable)
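
For completeness, the top-level wrap is then just something like the following (the model should already live on the target device at this point):

# Hypothetical usage: wrap the full model once; sub-modules matched by
# auto_wrap_policy get their own inner FSDP instances.
model = fsdp_wrap(model)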

I don’t know if this helps in your case.