Hi,
I am trying to find out why an FSDP-wrapped model fails to save its optimizer state when the saving code runs in all processes in parallel:
import torch
from torch.distributed import fsdp
...
model = ... # torch.nn.Module
model = fsdp.FSDP(model)
optimizer = torch.optim.AdamW(model.parameters())
...
state_dict_config = fsdp.FullStateDictConfig(
offload_to_cpu=True, rank0_only=True)
state_dict_type = fsdp.StateDictType.FULL_STATE_DICT
with fsdp.FullyShardedDataParallel.state_dict_type(
model,
state_dict_type=state_dict_type,
state_dict_config=state_dict_config):
model_state = model.state_dict()
optimizer_state = fsdp.FullyShardedDataParallel.full_optim_state_dict(
model=model, optim=optimizer, rank0_only=True)
fails with
File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1331, in full_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_impl(
File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
return _optim_state_dict(
File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1455, in _optim_state_dict
_gather_orig_param_state(
File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1697, in _gather_orig_param_state
value = value[: flat_param._numels[param_idx]].reshape(
RuntimeError: shape '[50304, 768]' is invalid for input of size 8
So the model state is collected, but the optimizer state is not. This is somewhat strange, given that the code follows the examples in the PyTorch tests and in Mosaic's code here.
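For comparison, here is a minimal non-distributed sketch of the same checkpointing pattern (plain torch, no FSDP; the layer shapes are arbitrary), which works as expected:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Run one step so the optimizer actually has state (exp_avg, exp_avg_sq).
model(torch.randn(3, 4)).sum().backward()
optimizer.step()

# Without FSDP, both state dicts can be collected directly on one process.
checkpoint = {
    "model": model.state_dict(),
    "optim": optimizer.state_dict(),
}
```

It is only the FSDP gathering of the optimizer state that fails for me.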
Any clue what could cause this on torch 2.0.0? Thank you very much in advance!