CUDA OOM error when loading sharded checkpoint

Hi all,

We fine-tuned Stability AI's StableLM-7B using Hugging Face's Trainer API (with FSDP) and saved the resulting checkpoints in the sharded format that is typical for large language models. Surprisingly, however, loading the model for inference fails on one of the shards with `Unable to load weights from pytorch checkpoint file`.
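For context, the inference-side load is just the standard `from_pretrained` call pointed at the checkpoint directory (the path and dtype below are illustrative, not our exact script):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory holding the pytorch_model-0000X-of-00004.bin shards plus
# pytorch_model.bin.index.json (the path here is illustrative)
checkpoint_dir = "outputs/stablelm-7b-finetuned/checkpoint-1000"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir,
    torch_dtype=torch.float16,
)
# -> OSError: Unable to load weights from pytorch checkpoint file ...
```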

We took a further investigative step and made a bare `torch.load` call on the problem shard, which reproduces a typical CUDA OOM error. The exceedingly strange thing about this OOM is that we are working on a node with 8x A100s (80 GB each), and the given state dict is only 171 kB (comprising only 7 layers of the model). So you can imagine that the following error came as quite a shock.
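For reference, the repro is literally this (the shard filename is illustrative):

```python
import torch

# Bare load of the problematic shard, no map_location
state_dict = torch.load("checkpoint-1000/pytorch_model-00003-of-00004.bin")
# -> torch.cuda.OutOfMemoryError (full message below)
```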

```
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 29.31 GiB (GPU 0; 79.19 GiB total capacity; 55.76 GiB already allocated; 22.48 GiB free; 55.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

After looking into this further, I found a few threads discussing this issue, like this one, and attempted some of the suggested fixes, namely loading the state dict onto CPU first. Doing so produced a different error.
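Concretely, the CPU-first attempt looked like this (same illustrative filename):

```python
import torch

# Remap every tensor in the shard onto CPU instead of its saved device
state_dict = torch.load(
    "checkpoint-1000/pytorch_model-00003-of-00004.bin",
    map_location="cpu",
)
# -> RuntimeError (below)
```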
```
RuntimeError: Trying to resize storage that is not resizable
```

So it seems that approach is out of the question. The other strange thing here is that the first two shards load without issue, while the third and fourth cannot be loaded. Additionally, nothing seems out of place in the shard-to-layer mapping JSON (a quick way to inspect it is sketched below). I am stumped. Perhaps it's a bug in Hugging Face Trainer's save function, but I figured I would start at the PyTorch source first.
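For anyone who wants to sanity-check the mapping the same way, this is roughly what I looked at (the index filename is the standard one transformers writes; the directory is illustrative):

```python
import json
from collections import Counter

with open("checkpoint-1000/pytorch_model.bin.index.json") as f:
    index = json.load(f)

# weight_map: parameter name -> shard filename
shard_counts = Counter(index["weight_map"].values())
for shard, n_tensors in sorted(shard_counts.items()):
    print(f"{shard}: {n_tensors} parameter tensors")
```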


It is also worth adding that the device named in the error (cuda:0) is completely empty at the moment `torch.load` is called.
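A sketch of that check, run immediately before the failing load:

```python
import torch

# Both report (approximately) zero right before the failing torch.load
print(torch.cuda.memory_allocated(0))  # 0
print(torch.cuda.memory_reserved(0))   # 0
```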