How to use FSDP with meta device for pre-trained models?

I’m trying to figure out how to use FSDP with the meta device, but the only documentation or example I could find is this one:

param_init_fn (Optional[Callable[[nn.Module], None]]) –
A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies param_init_fn if specified or calls nn.Module.reset_parameters() otherwise. For both cases, the implementation should only initialize the parameters/buffers of the module, not those of its submodules. This is to avoid re-initialization. In addition, FSDP also supports deferred initialization via torchdistX’s (https://github.com/pytorch/torchdistX) deferred_init() API, where the deferred modules are initialized by calling param_init_fn if specified or torchdistX’s default materialize_module() otherwise. If param_init_fn is specified, then it is applied to all meta-device modules, meaning that it should probably case on the module type. FSDP calls the initialization function before parameter flattening and sharding.
Example:
>>> import torch.nn as nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
>>> module = MyModule(device="meta")
>>> def my_init_fn(module: nn.Module):
...     # E.g. initialize depending on the module type
...     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device) # current CUDA device
>>> # With torchdistX
>>> from torchdistx import deferred_init
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)

From what I understand, the default way FSDP initializes a meta-device model (calling reset_parameters() on each module when no param_init_fn is given) produces random parameters, which doesn’t work for pre-trained models. The proposed alternative is torchdistX, but the torchdistX repository seems to be unmaintained, with its latest commit 9 months ago.

I would appreciate any guidance on how to combine FSDP, the meta device, and a pre-trained model.

There are a couple of ways to do it.

If you use FSDP’s sync_module_states parameter (FullyShardedDataParallel — PyTorch 2.1 documentation), you can load the model weights on rank0 only; FSDP will then broadcast them from rank0 to the remaining ranks and shard them.

For example, in IBM’s FMS repo we first load the state dict on rank0 only: https://github.com/foundation-model-stack/foundation-model-stack/blob/main/fms/models/__init__.py#L298 and then wrap the model in FSDP with sync_module_states: https://github.com/foundation-model-stack/foundation-model-stack/blob/main/fms/models/__init__.py#L186C9-L186C27. On the remaining ranks, the meta tensors are simply moved to the device with to_empty, and the weights are then copied into them from rank0.
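A minimal sketch of this rank0-loads, everyone-else-is-meta pattern might look like the following. It is not the FMS code; build_model_for_rank, materialize_empty and wrap_pretrained are hypothetical names, and it assumes PyTorch 2.1 or later, an already initialized NCCL process group, and one GPU per rank. The param_init_fn only allocates empty storage, because sync_module_states overwrites those tensors with rank0’s values.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_model_for_rank(build_fn, load_state_dict_fn, rank: int) -> nn.Module:
    """Hypothetical helper: rank0 holds the real pre-trained weights, others a meta skeleton."""
    if rank == 0:
        model = build_fn()  # real (CPU) tensors
        model.load_state_dict(load_state_dict_fn())  # checkpoint is read on rank0 only
    else:
        with torch.device("meta"):
            model = build_fn()  # shapes only, no parameter memory
    return model


def materialize_empty(module: nn.Module) -> None:
    # FSDP only calls this for modules still on the meta device (ranks != 0).
    # The storage is left uninitialized; sync_module_states fills it with the
    # parameters/buffers broadcast from rank0.
    device = torch.device("cuda", torch.cuda.current_device())
    module.to_empty(device=device, recurse=False)


def wrap_pretrained(model: nn.Module) -> FSDP:
    # Assumes dist.init_process_group("nccl") has been called and the current
    # CUDA device is set for this rank.
    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,  # broadcast rank0's params/buffers to all ranks
        param_init_fn=materialize_empty,
    )
```

On each rank you would then call something like wrap_pretrained(build_model_for_rank(MyModel, lambda: torch.load("model.pt", map_location="cpu"), dist.get_rank())), where MyModel and model.pt are placeholders. Passing the loader as a callable means the checkpoint is never read on the other ranks. The sketch passes no auto_wrap_policy, so the whole model becomes a single FSDP unit; a real model would normally add one.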

In this case, only one rank (rank0) has non-meta tensors and holds the full state_dict.

The above approach is intended for cases where the state dict isn’t sharded. Alternatively, you can load FSDP-sharded checkpoints into the materialized tensors after to_empty. In that case, the param_init_fn moves your meta tensors to the device, and you then load the sharded state dicts into them.
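For the sharded-checkpoint route, a sketch along these lines should work, assuming PyTorch 2.1’s torch.distributed.checkpoint API (dist_cp.load_state_dict with a FileSystemReader; later releases add torch.distributed.checkpoint.load as its replacement), an initialized NCCL process group, and a checkpoint that was saved from a SHARDED_STATE_DICT under the key "model". That key, and the names materialize_empty and load_sharded_pretrained, are assumptions for illustration. Every rank starts from the meta device, and the empty tensors created by the param_init_fn are filled from the checkpoint rather than by a broadcast.

```python
import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def materialize_empty(module: nn.Module) -> None:
    # Give this module's own meta tensors uninitialized GPU storage; the
    # sharded checkpoint loaded below supplies the actual values.
    device = torch.device("cuda", torch.cuda.current_device())
    module.to_empty(device=device, recurse=False)


def load_sharded_pretrained(build_fn, checkpoint_dir: str) -> FSDP:
    # Every rank builds the model on the meta device; nothing is read up front.
    with torch.device("meta"):
        model = build_fn()
    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        param_init_fn=materialize_empty,
    )
    # Under SHARDED_STATE_DICT, each rank's state dict describes only its own
    # shards, so each rank loads just the pieces it owns.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}  # "model" key is an assumption
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
        )
        model.load_state_dict(state_dict["model"])
    return model
```

Here no rank ever materializes the full state dict, which is the main advantage over the rank0-broadcast approach for very large models.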