Can someone shed light on the following?
Environment:

- Python 3.12.9
- torch 2.7.0+cu126
- NVIDIA driver 575.57.08

Launched with:

```
torchrun --standalone --nproc-per-node=2 bug_fsdp.py
```
Questions:
- Why do I have to wrap `torch.nn.Embedding` when using the distributed backend `"nccl"`? If I comment out the `torch.nn.Embedding` line in `custom_wrap_policy`, FSDP hangs; with `torch.nn.Embedding` wrapped, it does not. (With `"gloo"`, it does not hang either with or without wrapping `torch.nn.Embedding`.)
- I stumbled across the issue with `Embedding` by accident. What is it about the Llama model that I should have seen that would have told me this was necessary? What should I look for in other models I might try?
- It seems that if `embed_tokens` is worth sharding, then `lm_head` would be too. How would I shard it? The wrap policy only takes a class (`Linear`), and I might not always want to shard every `Linear`, but I don't see a way to specify sharding candidates by name (e.g. `lm_head`).
```python
import os
import torch
import torch.distributed as dist
import torch.distributed.fsdp as FSDP1
import transformers

model_dir = 'NousResearch/Llama-3.2-1B'

if __name__ == "__main__":
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])
    assert torch.cuda.is_available(), "cuda not available"
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda:{}".format(local_rank))
    dist.init_process_group(rank=rank, world_size=world_size, device_id=device, backend="nccl")

    def load_model(model_dir):
        config = transformers.AutoConfig.from_pretrained(
            model_dir,
            force_download=False,
            local_files_only=True,
            trust_remote_code=False,
        )
        model = transformers.AutoModelForCausalLM.from_config(config)
        return config, model

    print("[{}] loading model".format(rank))
    dist.barrier()
    if 0 == rank:  # load actual copy into CPU
        with torch.device("cpu"):
            config, model = load_model(model_dir)
    else:  # load to meta device so we don't use any memory
        with torch.device("meta"):
            config, model = load_model(model_dir)
    print("[{}] loaded model to device {}".format(rank, model.device))
    dist.barrier()

    def custom_wrap_policy(module, recurse, nonwrapped_numel):
        if recurse:
            return True
        else:
            return isinstance(
                module,
                (
                    transformers.models.llama.modeling_llama.LlamaDecoderLayer,
                    torch.nn.Embedding,  # !!! without wrapping Embedding, FSDP hangs!
                ),
            )

    def llama_init_fn(module):
        if isinstance(module, transformers.models.llama.modeling_llama.LlamaRMSNorm):
            # print("[{}] llama_init_fn(LlamaRMSNorm)".format(rank))
            pass
        elif isinstance(module, transformers.models.llama.modeling_llama.LlamaRotaryEmbedding):
            # print("[{}] llama_init_fn(LlamaRotaryEmbedding)".format(rank))
            pass
        elif isinstance(module, torch.nn.Embedding):
            # print("[{}] llama_init_fn(Embedding)".format(rank))
            pass
        else:
            pass
        # Allocate uninitialized storage on the GPU for this module's own params/buffers.
        module.to_empty(device=device, recurse=False)

    print("[{}] Wrapping FSDP".format(rank))
    dist.barrier()
    fsdp_model = FSDP1.FullyShardedDataParallel(
        model,
        sharding_strategy=FSDP1.ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=custom_wrap_policy,
        ignored_modules=None,
        param_init_fn=llama_init_fn,  # only runs on modules on the "meta" device (recursively)
        device_id=device,  # sharded model gets created on 'device'
        sync_module_states=True,  # fill in "meta" models from rank 0
    )
    dist.barrier()
    print("[{}] FSDP model on device {}".format(rank, fsdp_model.device))
    dist.barrier()
    if dist.is_initialized():
        dist.destroy_process_group()
```
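On the third question: `custom_wrap_policy` above is just a callable with the signature `(module, recurse, nonwrapped_numel)`, so it is not limited to class checks. One possible way to select modules by name, sketched under the assumption that the same `model` object is later passed to FSDP unchanged, is to resolve qualified names to module objects up front with `named_modules()` and match by identity inside the policy. `make_name_wrap_policy`, `ToyInner`, and `ToyCausalLM` are hypothetical; the toy model only mimics the `model.embed_tokens` / `lm_head` naming of the Hugging Face Llama classes:

```python
import torch.nn as nn


def make_name_wrap_policy(model: nn.Module, wrap_names: set[str], wrap_types: tuple = ()):
    """Build an FSDP-style auto_wrap_policy that wraps modules chosen by name or type."""
    # Resolve qualified names (as reported by named_modules) to module objects once.
    target_ids = {id(m) for name, m in model.named_modules() if name in wrap_names}

    def policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
        if recurse:
            return True  # always keep descending into children
        # Wrap if this exact module was selected by name, or matches a wrapped type.
        return id(module) in target_ids or isinstance(module, wrap_types)

    return policy


class ToyInner(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])


class ToyCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyInner()
        self.lm_head = nn.Linear(4, 10, bias=False)


if __name__ == "__main__":
    toy = ToyCausalLM()
    policy = make_name_wrap_policy(toy, {"lm_head"}, (nn.Embedding,))
    for name, module in toy.named_modules():
        print(name or "<root>", policy(module, recurse=False, nonwrapped_numel=0))
```

In the script this would be passed as something like `auto_wrap_policy=make_name_wrap_policy(model, {"lm_head"}, (LlamaDecoderLayer, torch.nn.Embedding))`. This only shows the selection mechanism: if `lm_head.weight` is tied to `embed_tokens.weight`, wrapping the two modules as separate units may not be valid, so that is worth checking first.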
Output when `torch.nn.Embedding` is wrapped (no hang):

```
[gpu:training] torchrun --standalone --nproc-per-node=2 bug_fsdp.py
[1] loading model
[0] loading model
[0] loaded model to device cpu
[1] loaded model to device meta
[0] Wrapping FSDP
[1] Wrapping FSDP
[1] FSDP model on device cuda:1
[0] FSDP model on device cuda:0
```
Output when `torch.nn.Embedding` is not wrapped (hangs until the NCCL watchdog times out):

```
[gpu:training] torchrun --standalone --nproc-per-node=2 bug_fsdp.py
[0] loading model
[1] loading model
[0] loaded model to device cpu
[1] loaded model to device meta
[0] Wrapping FSDP
[1] Wrapping FSDP
[rank1]:[E720 15:08:06.804992647 ProcessGroupNCCL.cpp:632] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=23, OpType=BROADCAST, NumelIn=262670336, NumelOut=262670336, Timeout(ms)=600000) ran for 600023 milliseconds before timing out.
[rank1]:[E720 15:08:06.805686296 ProcessGroupNCCL.cpp:2268] [PG ID 0 PG GUID 0(default_pg) Rank 1] failure detected by watchdog at work sequence id: 23 PG status: last enqueued work: 24, last completed work: 22
[rank1]:[E720 15:08:06.805721295 ProcessGroupNCCL.cpp:670] Stack trace of the failed collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value.
[rank1]:[E720 15:08:06.805909995 ProcessGroupNCCL.cpp:2103] [PG ID 0 PG GUID 0(default_pg) Rank 1] First PG on this rank to signal dumping.
[rank1]:[E720 15:08:06.016934898 ProcessGroupNCCL.cpp:1743] [PG ID 0 PG GUID 0(default_pg) Rank 1] Received a dump signal due to a collective timeout from this local rank and we will try our best to dump the debug info. Last enqueued NCCL work: 24, last completed NCCL work: 22.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank1]:[E720 15:08:06.017320737 ProcessGroupNCCL.cpp:1533] [PG ID 0 PG GUID 0(default_pg) Rank 1] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
[0] FSDP model on device cuda:0
[1] FSDP model on device cuda:1
[rank1]:[E720 15:08:07.571970177 ProcessGroupNCCL.cpp:684] [Rank 1] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E720 15:08:07.571994037 ProcessGroupNCCL.cpp:698] [Rank 1] To avoid data inconsistency, we are taking the entire process down.
[rank1]:[E720 15:08:07.573179334 ProcessGroupNCCL.cpp:1896] [PG ID 0 PG GUID 0(default_pg) Rank 1] Process group watchdog thread terminated with exception: [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=23, OpType=BROADCAST, NumelIn=262670336, NumelOut=262670336, Timeout(ms)=600000) ran for 600023 milliseconds before timing out.
Exception raised from checkTimeout at /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:635 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7f94c7f435e8 in /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x23d (0x7f94c925ea1d in /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::watchdogHandler() + 0xc80 (0x7f94c92607a0 in /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x14d (0x7f94c9261ead in /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xdbbf4 (0x7f94b9225bf4 in /opt/AI/training-4.52.4/bin/../lib/libstdc++.so.6)
frame #5: <unknown function> + 0x8a19a (0x7f9530c8a19a in /lib64/libc.so.6)
frame #6: <unknown function> + 0x10f210 (0x7f9530d0f210 in /lib64/libc.so.6)
[rank0]:[E720 15:08:07.649653797 ProcessGroupNCCL.cpp:1682] [PG ID 0 PG GUID 0(default_pg) Rank 0] Observed flight recorder dump signal from another rank via TCPStore.
[rank0]:[E720 15:08:07.649991016 ProcessGroupNCCL.cpp:1743] [PG ID 0 PG GUID 0(default_pg) Rank 0] Received a dump signal due to a collective timeout from rank 1 and we will try our best to dump the debug info. Last enqueued NCCL work: 25, last completed NCCL work: 24.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank0]:[E720 15:08:07.650321505 ProcessGroupNCCL.cpp:1533] [PG ID 0 PG GUID 0(default_pg) Rank 0] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
W0720 15:08:09.430000 1583972 /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 1583984 closing signal SIGTERM
E0720 15:08:12.050000 1583972 /opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: -6) local_rank: 1 (pid: 1583985) of binary: /opt/AI/training-4.52.4/bin/python
Traceback (most recent call last):
  File "/opt/AI/training-4.52.4/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/AI/training-4.52.4/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
```
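For what it's worth, the element count in the broadcast that timed out (`NumelIn=262670336`) can be compared against the model's dimensions. A quick arithmetic check, assuming the Llama-3.2-1B config values `vocab_size=128256` and `hidden_size=2048` (worth verifying in the downloaded `config.json`):

```python
# Values assumed from the Llama-3.2-1B config.json; verify against your copy.
vocab_size = 128256
hidden_size = 2048

# NumelIn reported by the NCCL watchdog for the BROADCAST that timed out.
numel_in = 262670336

embedding_numel = vocab_size * hidden_size  # one (vocab x hidden) matrix
remainder = numel_in - embedding_numel      # elements left over after that matrix

print(embedding_numel, remainder, remainder == hidden_size)
```

So the broadcast carries exactly one embedding-sized matrix plus one hidden-size vector. A plausible reading (an inference, not something the log states) is that these are the parameters left in the root FSDP unit when `Embedding` is not wrapped, for example `embed_tokens` plus the final RMSNorm weight, with a tied `lm_head` weight counted only once.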