Unable to resume a 64-node job: errors while loading the sharded optimizer state dict

Hi,

I am following this test code to save and load optimizer states using load_sharded_optimizer_state_dict: pytorch/test_fsdp_optim_state.py at 09b3517297d888714750217edf6c1f76a6340d85 · pytorch/pytorch · GitHub
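For context, my loading code follows roughly this pattern from that test (a simplified sketch; `model`, `optimizer`, and `CHECKPOINT_DIR` are placeholders for my actual objects, and the argument order of `FSDP.optim_state_dict_to_load` may differ between PyTorch versions):

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# model is an FSDP-wrapped module, optimizer its optimizer, and
# CHECKPOINT_DIR the directory the checkpoint was saved to.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Load the sharded model state first.
    state_dict = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    model.load_state_dict(state_dict["model"])

    # Then rebuild the optimizer state from the same checkpoint.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optim",
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    flattened = FSDP.optim_state_dict_to_load(
        optim_state["optim"], model, optimizer
    )
    optimizer.load_state_dict(flattened)
```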

Resuming worked fine on 1-node and 8-node jobs, but when I tested it on 64 nodes with a larger model, two errors appeared on two of the nodes.
Here is the first error:

optim_state = dist_cp_op.load_sharded_optimizer_state_dict(
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/checkpoint/optimizer.py", line 282, in load_sharded_optimizer_state_dict
      state_dict[key] = _shard_tensor(
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/api.py", line 70, in _shard_tensor
      st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 181, in shard
      dist.scatter(
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
      return func(*args, **kwargs)
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2774, in scatter
      _check_tensor_list(scatter_list, "scatter_list")
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 599, in _check_tensor_list
      raise RuntimeError(
  RuntimeError: Invalid function argument. Expected parameter `scatter_list` to be of type List[torch.Tensor].

And here is the second error:

File "/src/shopqa/generative/workflows/seq2seq/train/snapshot.py", line 363, in load_model_optim_final
      optim_state = dist_cp_op.load_sharded_optimizer_state_dict(
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/checkpoint/optimizer.py", line 282, in load_sharded_optimizer_state_dict
      state_dict[key] = _shard_tensor(
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/api.py", line 70, in _shard_tensor
      st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
    File "/opt/conda/lib/python3.9/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py", line 172, in shard
      assert local_tensor is not None
  AssertionError

My PyTorch version: 2.0.1
Is there anything I am missing when using load_sharded_optimizer_state_dict for large-scale training?