When I configure `state_dict_type` like this:
with fsdp_model.state_dict_type(
fsdp_model,
StateDictType.LOCAL_STATE_DICT,
LocalStateDictConfig(offload_to_cpu=True),
LocalOptimStateDictConfig(offload_to_cpu=True)
):
and then call:
model_state_dict = fsdp_model.state_dict()
the following error is raised:
Local shards' tensor requires_grad property is incompatible with tensor property on rank 1: tensor property requires_grad=True, local shard tensor requires_grad=False.
After debugging, I found that `ShardedTensor`
creates the CPU tensor with `requires_grad=False`
by default (see here). However, `requires_grad`
is `True` on the sharded tensor in my model, and this mismatch triggers the error.
Is this a bug or intended behavior? If it is intended, I would like to know why, and how I can enable `offload_to_cpu=True`
when using `LOCAL_STATE_DICT`.
Minimal code to reproduce the error:
import os.path as osp
import tempfile
import torch
import torch.distributed as dist
import torchvision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import \
LocalOptimStateDictConfig
from torch.distributed.fsdp.wrap import _module_wrap_policy
from torch.optim import SGD
def conv_wrap_policy(*args, module_classes=torch.nn.Conv2d, **kwargs):
    """FSDP auto-wrap policy that wraps every submodule of ``module_classes``.

    Adapter over ``_module_wrap_policy``: the single class passed as
    ``module_classes`` is packed into the one-element tuple the underlying
    policy receives. All other positional and keyword arguments are
    forwarded unchanged.
    """
    target_classes = (module_classes, )
    return _module_wrap_policy(*args, module_classes=target_classes, **kwargs)
def test_save_sharded_state_dict():
    """Round-trip a LOCAL_STATE_DICT checkpoint of an FSDP-wrapped ResNet-50.

    Runs one SGD step so the optimizer has momentum state. It then saves the
    model and optimizer state dicts under ``StateDictType.LOCAL_STATE_DICT``
    with ``offload_to_cpu=True`` and checks that every saved model entry
    reports a CPU device. Finally it loads the checkpoint back into the same
    model and optimizer.

    Must run after ``dist.init_process_group`` and ``torch.cuda.set_device``.
    Each rank writes its own ``ckpt_{rank}.pth`` into a temporary directory
    that is removed on exit.
    """
    model = torchvision.models.resnet50()
    rank = dist.get_rank()
    fsdp_model = FSDP(model, use_orig_params=False, device_id=rank,
                      auto_wrap_policy=conv_wrap_policy)
    optim = SGD(fsdp_model.parameters(), lr=0.001, momentum=0.9)

    # ResNet-50 returns per-sample logits, not a scalar; reduce them so that
    # backward() can create the implicit gradient.
    loss = fsdp_model(torch.ones(3, 3, 224, 224).cuda().float()).sum()
    loss.backward()
    optim.step()

    with tempfile.TemporaryDirectory() as tmpdir:
        # One path for both save and load, so the file that was written is
        # the one read back (and it is read before tmpdir is deleted).
        ckpt_path = osp.join(tmpdir, f'ckpt_{rank}.pth')
        # Save and load under the same state-dict type. FSDP interprets
        # state-dict keys according to the active type, so loading local
        # (flat-parameter) entries outside this context would not match.
        with fsdp_model.state_dict_type(
            fsdp_model,
            StateDictType.LOCAL_STATE_DICT,
            LocalStateDictConfig(offload_to_cpu=True),
            LocalOptimStateDictConfig(offload_to_cpu=True),
        ):
            optim_state_dict = fsdp_model.optim_state_dict(fsdp_model, optim)
            model_state_dict = fsdp_model.state_dict()
            torch.save(dict(model=model_state_dict, optim=optim_state_dict),
                       ckpt_path)
            for p in model_state_dict.values():
                assert p.device == torch.device('cpu')

            # The checkpoint was written by this process and holds
            # ShardedTensor objects, which the weights-only unpickler rejects.
            ckpt = torch.load(ckpt_path, weights_only=False)
            fsdp_model.load_state_dict(ckpt['model'])
            # FSDP.optim_state_dict() output is not in the form that
            # Optimizer.load_state_dict() consumes. Convert it back first.
            # Keyword arguments keep this working across torch versions
            # whose positional order differs.
            optim.load_state_dict(
                FSDP.optim_state_dict_to_load(
                    model=fsdp_model,
                    optim=optim,
                    optim_state_dict=ckpt['optim'],
                ))
if __name__ == '__main__':
    # One process per GPU is expected (presumably launched via torchrun). The
    # global rank is used directly as the CUDA device index, which only maps
    # onto real devices on a single node.
    dist.init_process_group(backend='nccl')
    try:
        torch.cuda.set_device(dist.get_rank())
        test_save_sharded_state_dict()
    finally:
        # Tear down the NCCL process group even if the test fails, so ranks
        # exit cleanly instead of leaving communicators alive.
        dist.destroy_process_group()