torch.distributed.checkpoint.save (dcp.save) does not save on all ranks

I am training with FSDP 2.0 (using TRL) and am having trouble getting SHARDED_STATE_DICT saving to work. When I run the code below, not all ranks successfully save their .distcp files. I have included the stack trace, but I am at a complete loss as to why this isn't working. Every once in a while it does work and all ranks save successfully. Additionally, if I pass device_ids=[0,1,2…] to dist.barrier, all ranks save their .distcp files, but the job then hangs at dist.barrier.
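On the `device_ids` hang: as I understand the NCCL backend, `device_ids` should name only the calling rank's own GPU, and the NCCL barrier uses the first entry of the list. Passing `[0, 1, 2, …]` on every rank therefore points every rank's barrier at GPU 0, which could explain the hang. A minimal sketch, assuming one process per GPU with `torch.cuda.set_device(local_rank)` called at startup; `barrier_on_local_device` is a hypothetical helper, not a PyTorch API:

```python
import torch
import torch.distributed as dist


def barrier_on_local_device() -> None:
    """Barrier pinned to this rank's own GPU (hypothetical helper).

    Assumes one process per GPU and that torch.cuda.set_device(local_rank)
    was called at startup, so current_device() is this rank's GPU.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process run: nothing to synchronize
    if dist.get_backend() == "nccl" and torch.cuda.is_available():
        # Pass only this rank's device; if every rank passed [0, 1, ...],
        # every rank's barrier would be issued on GPU 0.
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()
```

Each rank supplies a different single-element list, so no two ranks claim the same device.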

If there is a better way to do FSDP sharded checkpointing that ends up in torch.save format, please let me know. Safetensors would be preferable.


Linux

torch 2.6.0, CUDA 12.4

8x H200 with FSDP 2.0


(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193346 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpointSaving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) 
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345 with SHARDED_STATE_DICT via torch.distributed.checkpointSaving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) 
(task, pid=4336) Saving to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347 with SHARDED_STATE_DICT via torch.distributed.checkpoint
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 0
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 2
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 6
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 4
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193346/_dcp_tmp (DCP format) on rank 1
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 3Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193345/_dcp_tmp (DCP format) on rank 7
(task, pid=4336) 
(task, pid=4336) Saved model to /mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp (DCP format) on rank 5
(task, pid=4336) passed barrier on rank 6passed barrier on rank 1passed barrier on rank 3passed barrier on rank 2passed barrier on rank 7passed barrier on rank 4
(task, pid=4336) 
(task, pid=4336) 
(task, pid=4336) 
(task, pid=4336) 
(task, pid=4336) 
(task, pid=4336) passed barrier on rank 0passed barrier on rank 5
(task, pid=4336) 
(task, pid=4336) [rank0]: Traceback (most recent call last):
(task, pid=4336) [rank0]:   File "/root/sky_workdir/train.py", line 508, in <module>
(task, pid=4336) [rank0]:     main(config)
(task, pid=4336) [rank0]:   File "/root/sky_workdir/train.py", line 436, in main
(task, pid=4336) [rank0]:     trainer.save_model(sft_config.output_dir)
(task, pid=4336) [rank0]:   File "/root/sky_workdir/train.py", line 316, in _patched_save_model
(task, pid=4336) [rank0]:     dcp_to_torch_save(dcp_tmp_dir, "tmp.pt")
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/format_utils.py", line 212, in dcp_to_torch_save
(task, pid=4336) [rank0]:     _load_state_dict(
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 240, in _load_state_dict
(task, pid=4336) [rank0]:     _ = distW.all_gather("read", read_data)
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 256, in all_gather
(task, pid=4336) [rank0]:     raise CheckpointException(step, node_failures)
(task, pid=4336) [rank0]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
(task, pid=4336) [rank0]: Traceback (most recent call last): (RANK 0)
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 248, in all_gather
(task, pid=4336) [rank0]:     result = map_fun()
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
(task, pid=4336) [rank0]:     result = func(*args, **kwargs)
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 235, in read_data
(task, pid=4336) [rank0]:     all_reads = storage_reader.read_data(final_local_plan, planner)
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 643, in read_data
(task, pid=4336) [rank0]:     with self.fs.create_stream(new_path, "rb") as stream:
(task, pid=4336) [rank0]:   File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
(task, pid=4336) [rank0]:     return next(self.gen)
(task, pid=4336) [rank0]:   File "/root/venv/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 388, in create_stream
(task, pid=4336) [rank0]:     with cast(Path, path).open(mode) as stream:
(task, pid=4336) [rank0]:   File "/usr/lib/python3.10/pathlib.py", line 1119, in open
(task, pid=4336) [rank0]:     return self._accessor.open(self, mode, buffering, encoding, errors,
(task, pid=4336) [rank0]: FileNotFoundError: [Errno 2] No such file or directory: '/mnt/nebius-shared-fs/tmp/tmp-unimodel-qwen3-8b/tmp-unimodel-qwen3-8b-sft-2025-08-21-193347/_dcp_tmp/__1_0.distcp'
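Note the paths in the log: rank 0 saved under `...193347/_dcp_tmp`, rank 1 under `...193346/_dcp_tmp`, and ranks 2, 6 and 7 under `...193345/_dcp_tmp`. The ranks did not write into one shared directory, and the missing `__1_0.distcp` is the file rank 1 wrote elsewhere. Before converting, rank 0 could check that every rank's shard is present. A minimal sketch; the `__{rank}_{n}.distcp` naming is inferred from the file name in the traceback, and `missing_shard_ranks` is a hypothetical helper:

```python
import glob
import os


def missing_shard_ranks(ckpt_dir: str, world_size: int) -> list[int]:
    """Return the ranks that have no __{rank}_*.distcp file in ckpt_dir.

    Assumes FileSystemWriter's naming as seen in the traceback (e.g.
    __1_0.distcp) and that every rank writes at least one shard, which is
    the usual case when FSDP shards every parameter.
    """
    missing = []
    for rank in range(world_size):
        # glob.escape guards against glob metacharacters in the directory name
        pattern = os.path.join(glob.escape(ckpt_dir), f"__{rank}_*.distcp")
        if not glob.glob(pattern):
            missing.append(rank)
    return missing
```

Called on rank 0 after the barrier with `dist.get_world_size()`, this would presumably have flagged ranks 1, 2, 6 and 7 for the `...193347` directory, which is a clearer failure than the `FileNotFoundError` above.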

import types
from typing import Optional


def patch_trainer_save_model(trainer):
    """Monkey-patch Trainer.save_model to handle FSDP SHARDED state dict types."""

    original_save_model = trainer.save_model

    def _patched_save_model(
        self, output_dir: Optional[str] = None, _internal_call: bool = False
    ):
        import os
        import shutil
        import torch.distributed.checkpoint as dcp
        import torch.distributed as dist

        if output_dir is None:
            output_dir = self.args.output_dir
        if self.is_fsdp_enabled:
            if "SHARDED_STATE_DICT" in str(
                self.accelerator.state.fsdp_plugin.state_dict_type
            ):
                print(
                    f"Saving to {output_dir} with SHARDED_STATE_DICT via torch.distributed.checkpoint"
                )


                # Save model in DCP format to a temporary subfolder inside output_dir (no uuid)
                dcp_tmp_dir = os.path.join(output_dir, "_dcp_tmp")

                # Create directory on rank 0 first, then sync
                if dist.get_rank() == 0:
                    os.makedirs(dcp_tmp_dir, exist_ok=True)
                dist.barrier()  # Ensure directory exists before all ranks save

                dcp.save({"model": self.model}, checkpoint_id=dcp_tmp_dir)

                print(
                    f"Saved model to {dcp_tmp_dir} (DCP format) on rank {dist.get_rank()}"
                )

                dist.barrier()

                rank = dist.get_rank()
                print(f"passed barrier on rank {dist.get_rank()}")

                if rank == 0:
                    from torch.distributed.checkpoint.filesystem import FileSystemReader
                    from torch.distributed.checkpoint.state_dict_loader import (
                        _load_state_dict,
                    )
                    from torch.distributed.checkpoint.default_planner import (
                        _EmptyStateDictLoadPlanner,
                    )

                    sd = {}
                    _load_state_dict(
                        sd,
                        storage_reader=FileSystemReader(dcp_tmp_dir),
                        planner=_EmptyStateDictLoadPlanner(),
                        no_dist=True,
                    )
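                    # _EmptyStateDictLoadPlanner may rebuild the save-time nesting
                    # ({"model": {...}}); unwrap it, then strip any "model." prefix.
                    if isinstance(sd.get("model"), dict):
                        sd = sd["model"]
                    sd = {
                        (k.split(".", 1)[1] if k.startswith("model.") else k): v
                        for k, v in sd.items()
                    }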
                    shutil.rmtree(dcp_tmp_dir)
                    if self.args.should_save:
                        self._save(output_dir, state_dict=sd)
                        if self.args.push_to_hub and not _internal_call:
                            self.push_to_hub(
                                commit_message="Model save",
                                revision=self.args.hub_revision,
                            )
                dist.barrier()
                return
        # Non-FSDP or non-sharded: use original behavior
        return original_save_model(output_dir, _internal_call=_internal_call)

    trainer.save_model = types.MethodType(_patched_save_model, trainer)
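Because `output_dir` arrives with a per-rank timestamp (the log shows three different values), one defensive option is to have rank 0 choose the directory and broadcast it before `dcp_tmp_dir` is built. A minimal sketch using `dist.broadcast_object_list`; that the timestamp is generated somewhere in `train.py` (not shown) is an assumption, and computing it once and sharing it there is probably the cleaner fix. `agree_on_output_dir` is a hypothetical name:

```python
import torch.distributed as dist


def agree_on_output_dir(output_dir: str) -> str:
    """Make every rank use rank 0's output_dir (hypothetical helper).

    This is a collective: every rank must call it. With the NCCL backend
    the pickled object travels through the current CUDA device, so
    torch.cuda.set_device(local_rank) must already have been called.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return output_dir  # single process: nothing to agree on
    obj = [output_dir if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]
```

In `_patched_save_model` it would go at the top of the `is_fsdp_enabled` branch, as `output_dir = agree_on_output_dir(output_dir)`, so that `dcp_tmp_dir` is identical on all ranks.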

Thanks for sharing the detailed traceback—super helpful. From the error it looks like some ranks aren’t flushing their .distcp files before rank 0 tries to gather. You might try forcing a dist.barrier() right after dcp.save per rank or switching to distcp.save_state_dict with an explicit planner. Also worth checking the new torch.distributed.checkpoint.FileSystemWriter APIs, which handle sharded FSDP saves more reliably and support safetensors.


torch.distributed.checkpoint.FileSystemWriter did not seem to fix the issue. What do you mean by “forcing a per-rank dist.barrier”? I thought that is what I was doing by calling dist.barrier after dcp.save.
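`dist.barrier()` is already a collective that every rank must reach, so calling it after `dcp.save` is the per-rank barrier. As far as I can tell, `dcp.save` itself also ends with a collective step in which rank 0 collects every rank's write results before writing `.metadata`. A barrier only orders events in time; it cannot make ranks that were given different `checkpoint_id`s write into the same directory. A cheap guard is to gather every rank's path and fail loudly on a mismatch before saving. A minimal sketch with `dist.all_gather_object`; `assert_same_path` is a hypothetical name:

```python
import torch.distributed as dist


def assert_same_path(path: str) -> None:
    """Raise on every rank if the ranks disagree about path (hypothetical helper)."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # single process: trivially consistent
    paths = [None] * dist.get_world_size()
    dist.all_gather_object(paths, path)  # every rank receives every rank's path
    if len(set(paths)) != 1:
        # All ranks see the same list, so all raise together instead of hanging.
        raise RuntimeError(f"ranks disagree on checkpoint dir: {paths}")
```

Calling `assert_same_path(dcp_tmp_dir)` on every rank just before `dcp.save` would have turned this run into an immediate, readable error listing the three timestamped directories.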