Issue description
I am trying to load an FSDP checkpoint by broadcasting the weights from rank 0. The model is already correctly set up on the GPU of each rank. I use:
model_state_dict = torch.distributed.checkpoint.state_dict.set_model_state_dict(
model=self._model,
model_state_dict=model_state_dict,
options=torch.distributed.checkpoint.state_dict.StateDictOptions(
full_state_dict=True,
cpu_offload=True,
ignore_frozen_params=False,
broadcast_from_rank0=True,
),
)
When this call starts executing, I can see in nvidia-smi that CUDA memory usage on each GPU rises rapidly from ~20 GB to ~40 GB, which is essentially the full capacity of each GPU. Eventually the call fails with a CUDA OOM error (see the stack trace below). When I set `broadcast_from_rank0=False`, it works fine.
How can I use `broadcast_from_rank0=True` without running out of memory?
Traceback (most recent call last):
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train.py", line 19, in <module>
main()
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train.py", line 15, in main
TrainJob().run(config)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train_job.py", line 18, in run
self.run_trainer(config)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train_job.py", line 118, in run_trainer
trainer = Trainer(config=config)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/trainer.py", line 102, in __init__
self._maybe_restore_checkpoint()
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/trainer.py", line 124, in _maybe_restore_checkpoint
self.load_state_dict(state_dict)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/common/state_dict.py", line 96, in load_state_dict
load_state_dict_method(value)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/common/state_dict.py", line 82, in load_state_dict
self._load_custom_state_dict(state_dict)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/training_module.py", line 164, in _load_custom_state_dict
model_state_dict = torch.distributed.checkpoint.state_dict.set_model_state_dict(
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/checkpoint/state_dict.py", line 1184, in set_model_state_dict
return _load_model_state_dict(model, model_state_dict, info)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/checkpoint/state_dict.py", line 566, in _load_model_state_dict
_state_dict_fn(model, "load_state_dict")(
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
load(self, state_dict)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
load(child, child_state_dict, child_prefix) # noqa: F821
[Previous line repeated 2 more times]
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2183, in load
module._load_from_state_dict(
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2034, in _load_from_state_dict
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 73, in __call__
return self.hook(module, *args, **kwargs)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 849, in _pre_load_state_dict_hook
_pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 371, in _full_pre_load_state_dict_hook
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 139, in _enter_unshard_params_ctx
fsdp_state._unshard_params_ctx[module].__enter__()
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/external/python_runtime_x86_64-unknown-linux-gnu/lib/python3.10/contextlib.py", line 135, in __enter__
return next(self.gen)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 197, in _unshard_fsdp_state_params
_unshard(state, handle, computation_stream, computation_stream)
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 300, in _unshard
handle.unshard()
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_flat_param.py", line 1310, in unshard
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_flat_param.py", line 1337, in _alloc_padded_unsharded_flat_param
_alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/utils.py", line 186, in _alloc_storage
tensor._typed_storage()._resize_(size.numel())
File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/storage.py", line 1027, in _resize_
self._untyped_storage.resize_(size * self._element_size())
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.16 GiB. GPU 0 has a total capacity of 39.38 GiB of which 547.38 MiB is free. Including non-PyTorch memory, this process has 38.84 GiB memory in use. Of the allocated memory 35.91 GiB is allocated by PyTorch, and 363.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)