FSDP failed to save model checkpoints

Hi, I’m training a large LM on 8 A100-80GB GPUs using FSDP in HuggingFace’s Trainer. I specified the FSDP parameters as follows:

fsdp: full_shard auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap:
  - LlamaDecoderLayer

But when saving a model checkpoint, FSDP emits the following warning because of a CUDA out-of-memory error:

/home/ubuntu/miniconda3/envs/finetune-clm/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py:312: UserWarning: Failed to clone() tensor with name _fsdp_wrapped_module.model.layers.59.mlp.gate_proj.weight on rank 6. This may mean that this state_dict entry could point to invalid memory regions after returning from state_dict() call if this parameter is managed by FSDP. Please check clone implementation of _fsdp_wrapped_module.model.layers.59.mlp.gate_proj.weight. Error: CUDA out of memory. Tried to allocate 228.00 MiB (GPU 6; 79.20 GiB total capacity; 75.32 GiB already allocated; 75.25 MiB free; 77.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

However, there is plenty of free GPU memory during training:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   50C    P0   250W / 400W |  45141MiB / 81920MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   51C    P0   220W / 400W |  47039MiB / 81920MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   48C    P0   328W / 400W |  52993MiB / 81920MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   52C    P0   413W / 400W |  54209MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:09:00.0 Off |                    0 |
| N/A   48C    P0    94W / 400W |  47175MiB / 81920MiB |     99%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  On   | 00000000:0A:00.0 Off |                    0 |
| N/A   43C    P0   379W / 400W |  53687MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  On   | 00000000:0B:00.0 Off |                    0 |
| N/A   43C    P0   158W / 400W |  49709MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  On   | 00000000:0C:00.0 Off |                    0 |
| N/A   59C    P0   421W / 400W |  54341MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

It is only a warning, so the training process continues. But the saved checkpoints are broken and cannot be loaded, which means I cannot resume from an intermediate checkpoint.

When I tested the training job on a smaller GPU with a smaller model, FSDP saved model checkpoints without any problem, even though GPU memory was tighter (less than 1 GB free during training). But on the 80 GB A100 GPUs, it reports CUDA OOM even though almost 30 GB of memory is free during training. Is there anything I can do to get rid of this problem, or did I miss something?

My training env:

  • torch: 2.0.0+cu118
  • cuda: 11.8
  • transformers: 4.48.1

Thanks in advance!

Could you check if the CPU offloading workaround as described here works for you?


Awesome! Thanks @eqy! The workaround mentioned in the issue indeed solved my problem.

To anyone else facing this problem: when the HuggingFace Trainer calls model.state_dict() on an FSDP model during the saving step, FSDP runs _full_post_state_dict_hook after the call. This post-hook clones the tensors of the state dict on the GPU, and those clones are what triggered the CUDA OOM error. Because this is the full (unsharded) state dict, the cloned weights for the whole model pile up on the GPU, which is why a save can fail even when a lot of memory is free during training.
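A rough back-of-the-envelope check shows why ~30 GB of free memory may not be enough. The post does not state the model size (the warning only shows it has at least 60 decoder layers), so the 30B parameter count below is purely a hypothetical example, and the sketch assumes a rank ends up holding a GPU copy of every weight in the full state dict.

```python
def full_state_dict_gib(num_params: int, bytes_per_param: int = 2) -> float:
    """Approximate size (GiB) of a full, unsharded state dict.

    bytes_per_param: 2 for fp16/bf16 weights, 4 for fp32.
    """
    return num_params * bytes_per_param / 2**30


if __name__ == "__main__":
    # Hypothetical 30B-parameter model; the real size is not given in the post.
    n = 30_000_000_000
    print(f"bf16: {full_state_dict_gib(n, 2):.1f} GiB")  # ~55.9 GiB
    print(f"fp32: {full_state_dict_gib(n, 4):.1f} GiB")  # ~111.8 GiB
```

Either figure is larger than the roughly 27 to 36 GiB each GPU had free in the nvidia-smi output above, which is consistent with the OOM appearing only at save time.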

To solve this problem, the most elegant way is probably to call FSDP.set_state_dict_type with a FullStateDictConfig(offload_to_cpu=True) (sorry, I cannot post more links as a new user), or to use the FSDP.state_dict_type context manager to enable CPU offload when calling model.state_dict(). However, since I’m using FSDP through the HuggingFace Trainer, which exposes only a limited set of FSDP options, a quick and dirty workaround (as suggested in the GitHub issue) is to change the line
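For code that calls state_dict() directly (for example a custom training loop rather than Trainer), the context-manager approach might look like the sketch below. It assumes torch.distributed is already initialized and the model is already wrapped in FSDP; save_full_state_dict is a hypothetical helper name, and rank0_only=True (so only rank 0 keeps the full weights) is an extra choice on top of what the post describes.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_full_state_dict(model: FSDP, path: str) -> None:
    """Hypothetical helper: gather a full state dict in CPU memory and save it.

    Assumes the process group is initialized and `model` is FSDP-wrapped.
    Must be called on every rank, because state_dict() runs collectives.
    """
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        # With offload_to_cpu=True the unsharded weights are copied to CPU
        # memory, so the full copy does not have to fit on the GPU.
        state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
```

Trainer does not expose this save path directly, which is why the post falls back to patching PyTorch instead.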

state_dict[fqn] = state_dict[fqn].clone().detach()

to

state_dict[fqn] = state_dict[fqn].cpu().clone().detach()

in _full_post_state_dict_hook.

It is located at line 309 of torch/distributed/fsdp/_state_dict_utils.py in torch==2.0.0, or at line 2221 of torch/distributed/fsdp/fully_sharded_data_parallel.py in torch==1.13.0.
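Since these line numbers shift between PyTorch releases, a small search over the installed source can locate the line to patch. find_matching_lines is a hypothetical helper, and the module name and search string in the usage lines are assumptions based on the paths above; if the code has moved in your version, it simply reports no matches.

```python
import importlib
import inspect


def find_matching_lines(module_name: str, needle: str):
    """Return (source path, [(line number, line text), ...]) for lines containing `needle`."""
    module = importlib.import_module(module_name)
    path = inspect.getsourcefile(module)
    with open(path, encoding="utf-8") as f:
        hits = [(i, line.rstrip()) for i, line in enumerate(f, start=1) if needle in line]
    return path, hits


if __name__ == "__main__":
    # Assumed location for torch 2.0.x; adjust for other versions.
    path, hits = find_matching_lines(
        "torch.distributed.fsdp._state_dict_utils", ".clone().detach()"
    )
    print(path)
    for lineno, text in hits:
        print(f"{lineno}: {text.strip()}")
```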
