CUDA OOM error when loading sharded checkpoint

Hi all,

We fine-tuned Stability AI’s StableLM-7B using Hugging Face’s Trainer API (with FSDP) and then saved the resulting checkpoint in the sharded format that is typical for large language models. Surprisingly, however, attempting to load the model for inference fails on one of the shards with a strange error (Unable to load weights from pytorch checkpoint file).

To investigate further, we made a simple torch.load call on the problem shard and got a typical CUDA OOM error. The exceedingly strange thing about this OOM is that we are working on a node with 8x A100s (80 GB each), and the state dict in question is only 171 kB (comprising only 7 layers of the model). So you can imagine that seeing the following error was quite a shock:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 29.31 GiB (GPU 0; 79.19 GiB total capacity; 55.76 GiB already allocated; 22.48 GiB free; 55.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
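
For reference, the reproduction is essentially a single call like the one below (the shard filename is a placeholder, not our actual file):

import torch

# Load one problem shard; the filename is hypothetical.
# Without map_location, torch.load restores tensors to the device
# they were serialized from (cuda:0 in our case).
state_dict = torch.load("pytorch_model-00003-of-00004.bin")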

After looking into this further, I discovered a few threads discussing this issue, like this one, and attempted some of the fixes, namely loading the state dict on CPU first. After doing so, I received the following error:
RuntimeError: Trying to resize storage that is not resizable
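
For completeness, the CPU-first attempt is just the same call with map_location set (again, the filename is a placeholder):

import torch

# Same shard, but mapping all storages to CPU on load,
# as suggested in the threads linked above.
state_dict = torch.load("pytorch_model-00003-of-00004.bin", map_location="cpu")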

So it seems that approach is out of the question. To add to the strangeness, the first two shards load without issue, while the third and fourth cannot be loaded. Additionally, nothing seems particularly out of place in the shard-to-layer mapping JSON. I am stumped here. Perhaps it’s an error in the Hugging Face Trainer’s save function, but I figured I would start on the PyTorch side first.


It is also worth adding that the device named in the error (cuda:0) is completely empty when torch.load is called.

@irisz I was wondering if you could take a look. Thanks!

Could you share the code snippet for saving and loading?

If you are using the HF Trainer, the issue is that it calls torch.save, which is not really suitable for large-model / multi-node scenarios with PyTorch FSDP.
For really large models, you should be using sharded_state_dict and the distributed FileSystemWriter, which is optimized for handling multi-node, large-model scenarios.
Example distributed checkpoint save code here: pytorch/fsdp_checkpoint_example.py at e71ab214226af1f9dbded944e939c6447e0e8f09 · pytorch/pytorch · GitHub
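
To make that concrete, here is a minimal sketch of a sharded save using torch.distributed.checkpoint, assuming model is already FSDP-wrapped, the process group is initialized, and CHECKPOINT_DIR is a placeholder path:

import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

CHECKPOINT_DIR = "checkpoints/fsdp_sharded"  # placeholder path

# Each rank writes only its own shards via the distributed FileSystemWriter,
# so no single rank ever has to materialize the full model.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {"model": model.state_dict()}
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    )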

I have a related question: similarly, I am training a 7B model using Accelerate and FSDP with StateDictType.SHARDED_STATE_DICT.

What is the recommended way to load the sharded __{i}_{i}.distcp optimizer and parameter state dict files on a CPU or a single GPU without needing to initialize torch.distributed?

For training and resuming training, distributed checkpoint files make sense, but for inference there may be no need for FSDP. Can these files be converted into something that can be read simply with torch.load?

The following approach worked for me when saving full state dicts for both the optimizer and the model.

import os
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Save the FSDP full optimizer state dict (gathered onto rank 0 by default).
full_optim_save_path = os.path.join(output_dir, "optimizer.bin")
optim_state = FSDP.full_optim_state_dict(model, optimizer)
if accelerator.is_main_process:  # only global rank 0 holds the gathered dict
    torch.save(optim_state, full_optim_save_path)

# Save the FSDP full model state dict, offloaded to CPU and gathered on rank 0.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    model_state = model.state_dict()
    full_model_save_path = os.path.join(output_dir, "pytorch_model.bin")
    if accelerator.is_main_process:
        torch.save(model_state, full_model_save_path)