Switching between FSDP and DDP

Thanks for all the help with FSDP and for building it. I’m able to successfully train my models with FSDP, but some operations during my evaluation don’t quite work with FSDP, so I need to use DDP instead. I wanted to check whether there is a way to “unshard”/reset my model after training so that I can run evaluation (assuming the model fits in memory without FSDP in no_grad() mode).

A general flow would look like:

model = FSDP(model, ...)
train(model)
.
.
reset_model()  # this would basically remove the FSDP wrapper
model = DDP(model, ...)
evaluate(model)

Thanks!

We should figure out and document the canonical way to do this, but here is something that can hopefully unblock you for now:

from typing import Callable

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.nn.parallel import DistributedDataParallel as DDP

def reset_model(fsdp_model: FSDP, model_ctor: Callable) -> DDP:
    rank = dist.get_rank()
    with FSDP.state_dict_type(
        fsdp_model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
    ):
        # Since `rank0_only=True`, only rank 0 has a full state dict in memory
        state_dict = fsdp_model.state_dict()

    # Construct a nonwrapped model like what was passed into FSDP; its
    # parameter/buffer initialization does not matter since rank 0's values
    # are broadcast to all ranks below
    nonwrapped_model = model_ctor()  # on CPU
    if rank == 0:
        nonwrapped_model.load_state_dict(state_dict)
    # Constructing DDP broadcasts parameters/buffers from rank 0 to all ranks
    return DDP(nonwrapped_model.to(rank), device_ids=[rank])

Alternatively, model_ctor() can initialize the model directly on GPU, in which case the full state dict should stay on GPU as well (pass offload_to_cpu=False).
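For reference, here is a minimal untested sketch of that GPU variant (reset_model_gpu is a hypothetical name, and it assumes model_ctor constructs the model directly on the current CUDA device):

def reset_model_gpu(fsdp_model: FSDP, model_ctor: Callable) -> DDP:
    rank = dist.get_rank()
    with FSDP.state_dict_type(
        fsdp_model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=True, offload_to_cpu=False),  # keep the full state dict on GPU
    ):
        state_dict = fsdp_model.state_dict()
    nonwrapped_model = model_ctor()  # assumed to construct directly on device `rank`
    if rank == 0:
        nonwrapped_model.load_state_dict(state_dict)
    return DDP(nonwrapped_model, device_ids=[rank])  # no `.to(rank)` move needed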

I wrote this off the top of my head and have not tested it, but it conveys the general idea: save a full state dict and load it into a nonwrapped version of the module before wrapping with DDP. Since DDP broadcasts module states in its constructor, that saves some additional logic.
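To tie it back to the flow from the original post, here is a hedged end-to-end sketch (MyModel, train, and evaluate are placeholders standing in for your own model and loops, not real APIs):

import torch

def model_ctor():
    # Placeholder: must construct the same architecture that FSDP wrapped
    return MyModel()

rank = dist.get_rank()
fsdp_model = FSDP(MyModel().to(rank))
train(fsdp_model)  # placeholder training loop

ddp_model = reset_model(fsdp_model, model_ctor)
ddp_model.eval()
with torch.no_grad():
    evaluate(ddp_model)  # placeholder evaluation loop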