I am loading the model weights and optimizer state from a non-FSDP training session and trying to apply them to an FSDP-wrapped model, but I am getting the following error:
File "/lustre07/scratch/kleingeo/VertDetect/trainer/Trainer.py", line 566, in train_model
sharded_optim_state_dict = FSDP.scatter_full_optim_state_dict(full_optim_state_dict, net)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1617, in scatter_full_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1319, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 930, in _rekey_sharded_optim_state_dict
key.unflat_param_names, key.unflat_param_names
^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'unflat_param_names'
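For context, the checkpoint was written by the earlier non-FSDP run roughly like this (a sketch from memory, not the actual training script; the 'model'/'optimizer' keys match what I read back below):

# Sketch of the non-FSDP save (from memory): a plain dict of full,
# unsharded state dicts.
torch.save({'model': net.state_dict(),
            'optimizer': optimizer.state_dict()},
           '/model/checkpoint/dict/path.pt')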
The model is built, wrapped, and the optimizer loaded as follows:
import torch
from torch.distributed.fsdp import (BackwardPrefetch, MixedPrecision,
                                    ShardingStrategy)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Load the full (unsharded) checkpoint on CPU.
model_checkpoint = torch.load('/model/checkpoint/dict/path.pt', map_location='cpu')

net = model_init(**model_parameters)
net = net.to(self.device)

mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                 reduce_dtype=torch.bfloat16,
                                 buffer_dtype=torch.float32)
net = FSDP(net,
           sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
           mixed_precision=mixed_precision,
           device_id=gpu,
           auto_wrap_policy=ModuleWrapPolicy(module_classes={model_init}),
           backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
           )

optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate)

# Only rank 0 holds the full optimizer state dict; scatter shards it to all ranks.
full_optim_state_dict = None
if gpu == 0:
    full_optim_state_dict = model_checkpoint['optimizer']
sharded_optim_state_dict = FSDP.scatter_full_optim_state_dict(full_optim_state_dict, net)
optimizer.load_state_dict(sharded_optim_state_dict)
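If it's relevant: a plain torch.optim state dict keys its per-parameter state by integer index rather than by parameter name, which looks like the 'int' the traceback complains about. A quick way to confirm (hypothetical snippet, not part of my trainer):

# Hypothetical check: a vanilla optimizer state dict keys per-parameter
# state by integer index, e.g. [0, 1, 2, ...], rather than by FQN.
print(list(model_checkpoint['optimizer']['state'].keys()))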
Both the original (non-FSDP) training run and this one use the same conda environment, with torch==2.2.1.
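One thing I haven't been able to confirm: whether a non-FSDP optimizer state dict first needs to be rekeyed from integer param IDs to parameter names before it can be scattered. Something like the following (an untested guess on my part, based on the FSDP.rekey_optim_state_dict docstring, which rekeys against the unwrapped module):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, OptimStateKeyType

# Untested guess: rekey from integer param IDs to parameter names using an
# unwrapped copy of the model, then scatter the rekeyed dict as before.
unwrapped_net = model_init(**model_parameters)  # same architecture, not FSDP-wrapped
rekeyed_osd = FSDP.rekey_optim_state_dict(model_checkpoint['optimizer'],
                                          OptimStateKeyType.PARAM_NAME,
                                          unwrapped_net)

Is that the right approach here, or is there a recommended way to load an optimizer state from a non-FSDP run into an FSDP model?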