torch.distributed.checkpoint CUDA OOM with broadcast_from_rank0

Issue description

I am trying to load an FSDP checkpoint by broadcasting the weights from rank 0. The model is already set up correctly on the GPU of each rank. I call:

model_state_dict = torch.distributed.checkpoint.state_dict.set_model_state_dict(
    model=self._model,
    model_state_dict=model_state_dict,
    options=torch.distributed.checkpoint.state_dict.StateDictOptions(
        full_state_dict=True,
        cpu_offload=True,
        ignore_frozen_params=False,
        broadcast_from_rank0=True,
    ),
)

As soon as this call starts executing, nvidia-smi shows CUDA memory on each GPU rising rapidly, from ~20 GB to ~40 GB per GPU. Eventually it fails with CUDA OOM (see the stack trace below). With broadcast_from_rank0=False, the same call works fine.

How can I use broadcast_from_rank0=True without running out of memory?

Traceback (most recent call last):
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train.py", line 19, in <module>
    main()
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train.py", line 15, in main
    TrainJob().run(config)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train_job.py", line 18, in run
    self.run_trainer(config)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/components/scripts/train/train_job.py", line 118, in run_trainer
    trainer = Trainer(config=config)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/trainer.py", line 102, in __init__
    self._maybe_restore_checkpoint()
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/trainer.py", line 124, in _maybe_restore_checkpoint
    self.load_state_dict(state_dict)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/common/state_dict.py", line 96, in load_state_dict
    load_state_dict_method(value)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/common/state_dict.py", line 82, in load_state_dict
    self._load_custom_state_dict(state_dict)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/barrel/barrel/core/trainer/training_module.py", line 164, in _load_custom_state_dict
    model_state_dict = torch.distributed.checkpoint.state_dict.set_model_state_dict(
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/checkpoint/state_dict.py", line 1184, in set_model_state_dict
    return _load_model_state_dict(model, model_state_dict, info)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/checkpoint/state_dict.py", line 566, in _load_model_state_dict
    _state_dict_fn(model, "load_state_dict")(
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
    load(self, state_dict)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2189, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  [Previous line repeated 2 more times]
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2183, in load
    module._load_from_state_dict(
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 2034, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/nn/modules/module.py", line 73, in __call__
    return self.hook(module, *args, **kwargs)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 849, in _pre_load_state_dict_hook
    _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 371, in _full_pre_load_state_dict_hook
    _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 139, in _enter_unshard_params_ctx
    fsdp_state._unshard_params_ctx[module].__enter__()
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/external/python_runtime_x86_64-unknown-linux-gnu/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_unshard_param_utils.py", line 197, in _unshard_fsdp_state_params
    _unshard(state, handle, computation_stream, computation_stream)
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 300, in _unshard
    handle.unshard()
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_flat_param.py", line 1310, in unshard
    unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/fsdp/_flat_param.py", line 1337, in _alloc_padded_unsharded_flat_param
    _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/distributed/utils.py", line 186, in _alloc_storage
    tensor._typed_storage()._resize_(size.numel())
  File "/scratch/nikolay_nikolov/.cache/bazel/_bazel_nikolay_nikolov/79bf5e678fbb2019f1e30944a206f079/execroot/barrel/bazel-out/k8-opt/bin/barrel/pipes/vlams/torchrun.runfiles/pip-core_torch/site-packages/torch/storage.py", line 1027, in _resize_
    self._untyped_storage.resize_(size * self._element_size())
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.16 GiB. GPU 0 has a total capacity of 39.38 GiB of which 547.38 MiB is free. Including non-PyTorch memory, this process has 38.84 GiB memory in use. Of the allocated memory 35.91 GiB is allocated by PyTorch, and 363.02 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
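
As a rough reading of the OOM message, its numbers can be split up with a few lines of arithmetic. The helper name `oom_breakdown` is made up for this sketch, and the inputs are simply copied from the message. With these figures, PyTorch's caching allocator holds about 36.3 GiB, of which only ~363 MiB (about 1%) is reserved but unallocated, while roughly 2.6 GiB of the process's usage lies outside PyTorch. That suggests fragmentation is not the main problem here, so `expandable_segments:True` on its own is unlikely to help much: nearly all of the memory is held by live allocations.

```python
MIB_PER_GIB = 1024


def oom_breakdown(free_mib, process_gib, allocated_gib, reserved_unalloc_mib, request_gib):
    """Split the figures printed in a PyTorch CUDA OOM message (results in GiB)."""
    reserved_unalloc_gib = reserved_unalloc_mib / MIB_PER_GIB
    # Memory held by PyTorch's caching allocator: live tensors plus cached free blocks.
    reserved_gib = allocated_gib + reserved_unalloc_gib
    return {
        "reserved_gib": reserved_gib,
        # Process usage outside the caching allocator (typically the CUDA
        # context and library workspaces; the exact split is not in the log).
        "non_pytorch_gib": process_gib - reserved_gib,
        # Share of reserved memory that is cached but unused; this is the part
        # the expandable_segments hint in the message is aimed at.
        "unallocated_fraction": reserved_unalloc_gib / reserved_gib,
        # How much more free memory the failing allocation needed.
        "shortfall_gib": request_gib - free_mib / MIB_PER_GIB,
    }


# Figures copied from the OOM message in the traceback.
stats = oom_breakdown(
    free_mib=547.38,
    process_gib=38.84,
    allocated_gib=35.91,
    reserved_unalloc_mib=363.02,
    request_gib=1.16,
)
for name, value in stats.items():
    print(f"{name}: {value:.3f}")
```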

When using broadcast_from_rank0, do all ranks except rank 0 pass an empty state_dict? If so, could you repost this as an issue on the pytorch/pytorch GitHub repository? This sounds more like a bug.