Hi,
I am following this test code to save and load optimizer states using load_sharded_optimizer_state_dict: pytorch/test_fsdp_optim_state.py at 09b3517297d888714750217edf6c1f76a6340d85 · pytorch/pytorch · GitHub
Resuming went well on 1-node and 8-node jobs, but when I tested it on 64 nodes with a large model size, two errors appeared on two of the nodes.
Here is the first error:
optim_state = dist_cp_op.load_sharded_optimizer_state_dict(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/checkpoint/optimizer.py", line 282, in load_sharded_optimizer_state_dict
    state_dict[key] = _shard_tensor(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/api.py", line 70, in _shard_tensor
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 181, in shard
    dist.scatter(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2774, in scatter
    _check_tensor_list(scatter_list, "scatter_list")
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 599, in _check_tensor_list
    raise RuntimeError(
RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].
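For context on what this check enforces, here is a minimal single-process reproduction I put together myself (not taken from the original job; it assumes a gloo backend and a throwaway file-based rendezvous): on the source rank, `scatter_list` must be a `List[torch.Tensor]` with one entry per rank, and any non-tensor entry (e.g. a `None`) triggers the same "Expected parameter `scatter_list` to be of type List[torch.Tensor]" error before the scatter ever runs.

```python
import tempfile
import torch
import torch.distributed as dist

# Single-process "distributed" setup so dist.scatter can be called at all.
init_file = tempfile.NamedTemporaryFile(delete=False).name
dist.init_process_group(
    "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
)

out = torch.zeros(2)

# Valid call: a list of tensors, one per rank (world_size == 1 here).
out_ok = dist.scatter(out, scatter_list=[torch.ones(2)], src=0)
print(out.tolist())

# Invalid call: a None entry in scatter_list fails the tensor-list check,
# the same argument validation that raised in the traceback above.
# (Newer torch versions raise TypeError instead of RuntimeError here.)
try:
    dist.scatter(out, scatter_list=[None], src=0)
    err_msg = ""
except (RuntimeError, TypeError) as e:
    err_msg = str(e)
print(err_msg)

dist.destroy_process_group()
```

In `chunk_sharding_spec.shard`, the scatter source rank is the only rank expected to hold the full tensor list, so my guess is that at 64-node scale some rank ends up on the scatter path with a list that was never populated.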
And here is the second error:
  File "/src/shopqa/generative/workflows/seq2seq/train/snapshot.py", line 363, in load_model_optim_final
    optim_state = dist_cp_op.load_sharded_optimizer_state_dict(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/checkpoint/optimizer.py", line 282, in load_sharded_optimizer_state_dict
    state_dict[key] = _shard_tensor(
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/api.py", line 70, in _shard_tensor
    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 172, in shard
    assert local_tensor is not None
AssertionError
My PyTorch version: 2.0.1
Is there anything I am missing when using load_sharded_optimizer_state_dict in large-scale training?