I am training with FSDP 2.0 (using TRL) and am having trouble getting SHARDED_STATE_DICT saving to work. When I run the code below, for some reason not all ranks successfully save their .distcp shard. I have included the stack trace, but I am at a complete loss as to why this isn't working; every once in a while it actually does work and all ranks save successfully. Additionally, if I specify device_ids as [0,1,2…] in dist.barrier, all ranks do write their .distcp files, but then we hang at the barrier.
If there is a better way to do FSDP sharded checkpointing and end up with a torch.save-style single file, please let me know. Preferably safetensors.
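For reference, this is roughly the flow I'm hoping for — just a sketch of what I imagine, not something I have working. get_model_state_dict / StateDictOptions are the torch.distributed.checkpoint.state_dict APIs I'm aware of; the safetensors write and the rank-0 guard are my own assumptions:

# Sketch only: gather a full (unsharded) state dict from the FSDP model and
# write it out as safetensors. `model` and the output path are placeholders.
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
from safetensors.torch import save_file


def save_full_safetensors(model, path="model.safetensors"):
    # full_state_dict=True gathers the sharded params into plain tensors;
    # cpu_offload=True keeps the gathered copy off the GPUs.
    state_dict = get_model_state_dict(
        model,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
    if dist.get_rank() == 0:
        # save_file refuses shared storage (e.g. tied embeddings), which may
        # need special handling depending on the model.
        save_file(state_dict, path)
    dist.barrier()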
Linux
torch 2.6.0, CUDA 12.4
H200x8 with FSDP 2.0
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193346 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 0
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 2
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 6
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 4
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193346/_dcp_tmp (DCP format) on rank 1
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 3
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 7
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 5
(task, pid=4336) passed barrier on rank 6
(task, pid=4336) passed barrier on rank 1
(task, pid=4336) passed barrier on rank 3
(task, pid=4336) passed barrier on rank 2
(task, pid=4336) passed barrier on rank 7
(task, pid=4336) passed barrier on rank 4
(task, pid=4336) passed barrier on rank 0
(task, pid=4336) passed barrier on rank 5
(task, pid=4336) [rank0]: Traceback (most recent call last):
(task, pid=4336) [rank0]: File "/root/sky_workdir/train.py", line 508, in <module>
(task, pid=4336) [rank0]: main(config)
(task, pid=4336) [rank0]: File "/root/sky_workdir/train.py", line 436, in main
(task, pid=4336) [rank0]: trainer.save_model(sft_config.output_dir)
(task, pid=4336) [rank0]: File "/root/sky_workdir/train.py", line 316, in _patched_save_model
(task, pid=4336) [rank0]: dcp_to_torch_save(dcp_tmp_dir, "tmp.pt")
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/format_utils.py", line 212, in dcp_to_torch_save
(task, pid=4336) [rank0]: _load_state_dict(
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 240, in _load_state_dict
(task, pid=4336) [rank0]: _ = distW.all_gather("read", read_data)
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 256, in all_gather
(task, pid=4336) [rank0]: raise CheckpointException(step, node_failures)
(task, pid=4336) [rank0]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
(task, pid=4336) [rank0]: Traceback (most recent call last): (RANK 0)
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 248, in all_gather
(task, pid=4336) [rank0]: result = map_fun()
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
(task, pid=4336) [rank0]: result = func(*args, **kwargs)
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 235, in read_data
(task, pid=4336) [rank0]: all_reads = storage_reader.read_data(final_local_plan, planner)
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 643, in read_data
(task, pid=4336) [rank0]: with self.fs.create_stream(new_path, "rb") as stream:
(task, pid=4336) [rank0]: File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
(task, pid=4336) [rank0]: return next(self.gen)
(task, pid=4336) [rank0]: File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 388, in create_stream
(task, pid=4336) [rank0]: with cast(Path, path).open(mode) as stream:
(task, pid=4336) [rank0]: File "/usr/lib/python3.10/pathlib.py", line 1119, in open
(task, pid=4336) [rank0]: return self._accessor.open(self, mode, buffering, encoding, errors,
(task, pid=4336) [rank0]: FileNotFoundError: [Errno 2] No such file or directory: '/mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp/__1_0.distcp'
import types
from typing import Optional


def patch_trainer_save_model(trainer):
    """Monkey-patch Trainer.save_model to handle FSDP SHARDED state dict types."""
    original_save_model = trainer.save_model

    def _patched_save_model(
        self, output_dir: Optional[str] = None, _internal_call: bool = False
    ):
        import os
        import shutil
        import torch.distributed.checkpoint as dcp
        import torch.distributed as dist

        if output_dir is None:
            output_dir = self.args.output_dir

        if self.is_fsdp_enabled:
            if "SHARDED_STATE_DICT" in str(
                self.accelerator.state.fsdp_plugin.state_dict_type
            ):
                print(
                    f"Saving to {output_dir} with SHARDED_STATE_DICT via torch.distributed.checkpoint"
                )
                # Save model in DCP format to a temporary subfolder inside output_dir (no uuid)
                dcp_tmp_dir = os.path.join(output_dir, "_dcp_tmp")

                # Create directory on rank 0 first, then sync
                if dist.get_rank() == 0:
                    os.makedirs(dcp_tmp_dir, exist_ok=True)
                dist.barrier()  # Ensure directory exists before all ranks save

                dcp.save({"model": self.model}, checkpoint_id=dcp_tmp_dir)
                print(
                    f"Saved model to {dcp_tmp_dir} (DCP format) on rank {dist.get_rank()}"
                )

                dist.barrier()
                rank = dist.get_rank()
                print(f"passed barrier on rank {rank}")

                if rank == 0:
                    from torch.distributed.checkpoint.filesystem import FileSystemReader
                    from torch.distributed.checkpoint.state_dict_loader import (
                        _load_state_dict,
                    )
                    from torch.distributed.checkpoint.default_planner import (
                        _EmptyStateDictLoadPlanner,
                    )

                    # Read the DCP shards back into a regular in-memory state dict
                    # on rank 0 only (no_dist=True).
                    sd = {}
                    _load_state_dict(
                        sd,
                        storage_reader=FileSystemReader(dcp_tmp_dir),
                        planner=_EmptyStateDictLoadPlanner(),
                        no_dist=True,
                    )
                    # dcp.save stored everything under the "model" key; strip that prefix.
                    sd = {
                        (k.split(".", 1)[1] if k.startswith("model.") else k): v
                        for k, v in sd.items()
                    }
                    shutil.rmtree(dcp_tmp_dir)

                    if self.args.should_save:
                        self._save(output_dir, state_dict=sd)
                    if self.args.push_to_hub and not _internal_call:
                        self.push_to_hub(
                            commit_message="Model save",
                            revision=self.args.hub_revision,
                        )

                dist.barrier()
                return

        # Non-FSDP or non-sharded: use original behavior
        return original_save_model(output_dir, _internal_call=_internal_call)

    trainer.save_model = types.MethodType(_patched_save_model, trainer)
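For context, the patch is applied right after the trainer is constructed, roughly like this (a sketch only — the real SFTTrainer/dataset setup in train.py is elided, and the variable names are just what I use):

from trl import SFTConfig, SFTTrainer

# ... model, tokenizer and train_dataset are set up elsewhere in train.py ...
sft_config = SFTConfig(output_dir=output_dir)
trainer = SFTTrainer(model=model, args=sft_config, train_dataset=train_dataset)

patch_trainer_save_model(trainer)  # swaps in _patched_save_model above
trainer.train()
trainer.save_model(sft_config.output_dir)  # this is the call that fails in the trace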