AttributeError: 'int' object has no attribute 'unflat_param_names' in FSDP.scatter_full_optim_state_dict

I am loading the model weights and optimizer state from a non-FSDP training session and trying to apply them to an FSDP-wrapped model, but I am getting the following error:

  File "/lustre07/scratch/kleingeo/VertDetect/trainer/Trainer.py", line 566, in train_model
    sharded_optim_state_dict = FSDP.scatter_full_optim_state_dict(full_optim_state_dict, net)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1617, in scatter_full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1319, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kleingeo/vert_detect/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 930, in _rekey_sharded_optim_state_dict
    key.unflat_param_names, key.unflat_param_names
    ^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'unflat_param_names'

[2024-04-04 10:35:35,613] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-04-04 10:35:35,613] torch._dynamo.utils: [INFO] Function, Runtimes (s)

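For context, the optimizer entry in the checkpoint comes straight from the non-FSDP run. As far as I understand, a plain optimizer.state_dict() keys its per-parameter state by integer parameter index rather than by name, which looks consistent with the 'int' object in the traceback. The checkpoint would have been written along these lines (illustrative sketch, not the exact saving code):

# Illustrative sketch (assumed, not the exact saving code) of how the
# checkpoint is produced in the non-FSDP run. Note that
# optimizer.state_dict()['state'] is keyed by integer parameter index.
torch.save({'model': net.state_dict(),
            'optimizer': optimizer.state_dict()},
           '/model/checkpoint/dict/path.pt')
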
The loading of the optimizer is as follows:


import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

model_checkpoint = torch.load('/model/checkpoint/dict/path.pt', map_location='cpu')

net = model_init(**model_parameters)
net = net.to(self.device)

mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                 reduce_dtype=torch.bfloat16,
                                 buffer_dtype=torch.float32)


net = FSDP(net,
           sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
           mixed_precision=mixed_precision,
           device_id=gpu,
           auto_wrap_policy=ModuleWrapPolicy(module_classes={model_init}),
           backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
           )

optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate)

full_optim_state_dict = None
if gpu == 0:
    full_optim_state_dict = model_checkpoint['optimizer']

sharded_optim_state_dict = FSDP.scatter_full_optim_state_dict(full_optim_state_dict, net)

optimizer.load_state_dict(sharded_optim_state_dict)

Both the original (non-FSDP) run and this one use the same conda environment, with torch==2.2.1.
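
If the integer keys do come from the plain optimizer.state_dict(), do I need to re-key the state dict to parameter names before scattering? Something like the sketch below, based on the FSDP.rekey_optim_state_dict docs (untested; re-keying against the unwrapped module is my assumption):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, OptimStateKeyType

# Untested sketch: convert the index-keyed optimizer state dict to
# parameter-name keys using the unwrapped module (per the
# rekey_optim_state_dict docs), then scatter as before.
unwrapped_net = model_init(**model_parameters)  # same module, not FSDP-wrapped

full_optim_state_dict = None
if gpu == 0:
    full_optim_state_dict = FSDP.rekey_optim_state_dict(
        model_checkpoint['optimizer'],
        OptimStateKeyType.PARAM_NAME,
        unwrapped_net,
    )

sharded_optim_state_dict = FSDP.scatter_full_optim_state_dict(full_optim_state_dict, net)
optimizer.load_state_dict(sharded_optim_state_dict)

Or is scatter_full_optim_state_dict expected to handle an index-keyed state dict directly?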