Awesome! Thanks @eqy! The workaround trick mentioned in the issue indeed solved my problem.
To anyone who is also facing this problem: the cause is that when the HuggingFace Trainer calls model.state_dict() during the saving step with an FSDP model, it triggers a _full_post_state_dict_hook that is executed after the call. In this post hook, every tensor in the state dict is cloned while still on the GPU, which is what caused the CUDA OOM error.
To solve this problem, the most elegant way is probably to use FSDP.set_state_dict_type with offload_to_cpu=True (sorry, I cannot include more links as a new user), or to use the FSDP.state_dict_type context manager to enable CPU offload when calling model.state_dict().
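For reference, here is a rough sketch of what those two options could look like outside the Trainer (my own sketch, not code from the issue), assuming model is already an FSDP-wrapped module on torch 2.0:

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# `model` is assumed to be an FSDP-wrapped module.

# Option 1: set the state dict type once, so every later model.state_dict()
# call gathers full parameters and offloads them to CPU.
FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True),
)
state_dict = model.state_dict()

# Option 2: scope the same setting with the context manager.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True),
):
    state_dict = model.state_dict()

With offload_to_cpu=True the gathered parameters are moved to CPU, so (as far as I understand) the clone in the post hook no longer allocates extra GPU memory.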
However, since I'm using FSDP through the HuggingFace Trainer and the exposed FSDP options are pretty limited, a quick and dirty workaround (as suggested in the GitHub issue) is to just change the line
state_dict[fqn] = state_dict[fqn].clone().detach()
to
state_dict[fqn] = state_dict[fqn].cpu().clone().detach()
in _full_post_state_dict_hook, so that the extra copy is allocated in host memory instead of on the GPU.
That line is at line 309 of torch/distributed/fsdp/_state_dict_utils.py in torch 2.0.0, or line 2221 of torch/distributed/fsdp/fully_sharded_data_parallel.py in torch 1.13.0.