FSDP failed to save model checkpoints

Awesome! Thanks @eqy! The workaround mentioned in the issue indeed solved my problem.

To anyone else facing this problem: the cause is that when the HuggingFace Trainer calls model.state_dict() on an FSDP model during the saving step, FSDP runs _full_post_state_dict_hook after the call. This post hook clones every tensor in the state dict on the GPU, and those extra copies are what caused the CUDA OOM error.
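A rough back-of-the-envelope sketch of why that clone runs out of memory: each cloned tensor stays referenced by the returned dict, so on every rank that runs the hook the clones add up to one full, unsharded copy of the weights. The model size (7B parameters), dtype (bf16, 2 bytes) and world size (8) below are hypothetical numbers for illustration, and the estimate ignores activations, optimizer state and allocator overhead.

```python
def gib(n_bytes: float) -> float:
    """Convert bytes to GiB."""
    return n_bytes / 2**30


def cloned_state_dict_bytes(num_params: int, bytes_per_param: int) -> int:
    # Each clone stays referenced by the returned state dict, so together the
    # clones form one full, unsharded copy of the weights on the GPU.
    return num_params * bytes_per_param


def param_shard_bytes(num_params: int, bytes_per_param: int, world_size: int) -> float:
    # Under full sharding, each rank normally holds only 1/world_size of the weights.
    return num_params * bytes_per_param / world_size


# Hypothetical example: a 7B-parameter model in bf16 (2 bytes) on 8 GPUs.
full_copy = cloned_state_dict_bytes(7_000_000_000, 2)
shard = param_shard_bytes(7_000_000_000, 2, 8)
print(f"clones: {gib(full_copy):.2f} GiB, shard: {gib(shard):.2f} GiB")
```

With these numbers the clones need about 13 GiB per GPU, roughly eight times the ~1.6 GiB parameter shard, which is easily enough to tip an already busy GPU into OOM.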

The most elegant fix is probably to use FSDP.set_state_dict_type with a FullStateDictConfig(offload_to_cpu=True) (sorry, I cannot include more links as a new user), or to use the FSDP.state_dict_type context manager to enable CPU offload when calling model.state_dict(). However, since I'm using FSDP through the HuggingFace Trainer, which exposes only a limited part of the FSDP API, a quick and dirty workaround (as suggested in the GitHub issue) is to change the line

state_dict[fqn] = state_dict[fqn].clone().detach()

to

state_dict[fqn] = state_dict[fqn].cpu().clone().detach()

in _full_post_state_dict_hook.

It is located at line 309 of torch/distributed/fsdp/_state_dict_utils.py in torch 2.0.0, or at line 2221 of torch/distributed/fsdp/fully_sharded_data_parallel.py in torch 1.13.0.
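Since those line numbers shift between releases, it may be safer to search for the exact statement than to trust the line number. A minimal sketch, using only the standard library and assuming the statement appears verbatim in your installed torch (it may be formatted differently in other versions, in which case nothing is found); the helper names find_clone_lines and patch_source are hypothetical:

```python
from __future__ import annotations

# The statement quoted above, before and after the workaround.
ORIGINAL = "state_dict[fqn] = state_dict[fqn].clone().detach()"
PATCHED = "state_dict[fqn] = state_dict[fqn].cpu().clone().detach()"


def find_clone_lines(source: str) -> list[int]:
    """Return the 1-based line numbers that still clone on the GPU."""
    return [
        lineno
        for lineno, line in enumerate(source.splitlines(), start=1)
        if ORIGINAL in line
    ]


def patch_source(source: str) -> str:
    """Apply the workaround: move each tensor to CPU before cloning it.

    Idempotent, because PATCHED does not contain ORIGINAL as a substring.
    """
    return source.replace(ORIGINAL, PATCHED)
```

To find the installed copy, print torch.distributed.fsdp._state_dict_utils.__file__ (torch 2.0.0) or torch.distributed.fsdp.fully_sharded_data_parallel.__file__ (torch 1.13.0), read it with pathlib.Path(...).read_text(), and back the file up before writing anything. These are private modules, so the same quick-and-dirty caveat applies, and the edit is lost on every reinstall or upgrade.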
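If you call model.state_dict() yourself instead of going through the Trainer, the cleaner option mentioned earlier (CPU offload through FSDP.state_dict_type) avoids patching torch. A minimal sketch, assuming the process group is already initialized, the model is already wrapped in FSDP, and the function is called on every rank; save_full_state_dict is a hypothetical helper, and rank0_only=True is an extra choice on top of the offload_to_cpu setting described above:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_full_state_dict(model: FSDP, path: str) -> None:
    """Hypothetical helper: gather a full state dict on CPU and save it on rank 0.

    Must run on every rank, because building a FULL_STATE_DICT all-gathers
    the sharded parameters (a collective operation).
    """
    # offload_to_cpu=True moves each gathered tensor to CPU instead of keeping
    # a full extra copy on the GPU; rank0_only=True keeps the result only on
    # rank 0, so the other ranks do not build a CPU copy either.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
```

The context manager restores the previous state-dict type on exit, so it only affects this one save; FSDP.set_state_dict_type with the same arguments would instead make the setting persistent for later state_dict() calls.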
