[FSDP] Cannot enable `offload_to_cpu=True` when using `LOCAL_STATE_DICT`

When I configure state_dict_type like this:

with fsdp_model.state_dict_type(
    fsdp_model,
    StateDictType.LOCAL_STATE_DICT,
    LocalStateDictConfig(offload_to_cpu=True),
    LocalOptimStateDictConfig(offload_to_cpu=True)
):

and then call

model_state_dict = fsdp_model.state_dict()

the following error is raised:

Local shards' tensor requires_grad property is incompatible with tensor property on rank 1: tensor property requires_grad=True, local shard tensor requires_grad=False. 

After debugging, I found that the ShardedTensor creates the offloaded CPU tensor with requires_grad=False here by default. However, requires_grad is True on the sharded tensors of my model, and this inconsistency triggers the error.
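To make the mismatch concrete, here is a minimal, non-distributed sketch of what I think the property check boils down to (the variable names are mine, not the actual PyTorch internals):

import torch

# A local shard of a trainable parameter has requires_grad=True.
param_shard = torch.randn(4, 4, requires_grad=True)

# Offloading allocates a fresh CPU tensor and copies the data in, so the
# copy comes out with the default requires_grad=False.
cpu_shard = torch.empty_like(param_shard, device='cpu')
with torch.no_grad():
    cpu_shard.copy_(param_shard)

# The ShardedTensor validation compares each local shard against the tensor
# properties recorded in its metadata and raises on any mismatch.
expected_requires_grad = True  # taken from the ShardedTensor's metadata
if cpu_shard.requires_grad != expected_requires_grad:
    raise ValueError(
        "Local shards' tensor requires_grad property is incompatible with "
        f"tensor property: tensor property requires_grad={expected_requires_grad}, "
        f"local shard tensor requires_grad={cpu_shard.requires_grad}"
    )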

I’m curious whether this is a bug or intended behavior. If it is intended, I’d like to understand why, and how I can enable offload_to_cpu=True when using LOCAL_STATE_DICT.
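For now, the only workaround I can think of is to keep offload_to_cpu=False and move the shards to CPU myself after calling state_dict(). This is just an unverified sketch: it assumes ShardedTensor exposes a .cpu() method that moves the local shards, and I haven't checked whether that call runs into the same property check:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import LocalOptimStateDictConfig

# fsdp_model is the same wrapped model as in the repro script below.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.LOCAL_STATE_DICT,
    LocalStateDictConfig(offload_to_cpu=False),
    LocalOptimStateDictConfig(offload_to_cpu=False),
):
    model_state_dict = fsdp_model.state_dict()

# Assumption: ShardedTensor.cpu() returns a copy whose local shards live on CPU;
# plain tensors in the state dict (e.g. buffers) support .cpu() as well.
model_state_dict = {k: v.cpu() for k, v in model_state_dict.items()}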

Minimal code to reproduce the error:

import os.path as osp
import tempfile

import torch
import torch.distributed as dist
import torchvision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import \
    LocalOptimStateDictConfig
from torch.distributed.fsdp.wrap import _module_wrap_policy
from torch.optim import SGD


def conv_wrap_policy(*args, module_classes=torch.nn.Conv2d, **kwargs):
    # Wrap every Conv2d layer in its own FSDP unit.
    return _module_wrap_policy(*args, module_classes=(module_classes, ), **kwargs)


def test_save_sharded_state_dict():
    # Save a checkpoint
    model = torchvision.models.resnet50()
    rank = dist.get_rank()
    fsdp_model = FSDP(model, use_orig_params=False, device_id=rank, auto_wrap_policy=conv_wrap_policy)
    # fsdp_model.get_state_dict_type(fsdp_model)
    optim = SGD(fsdp_model.parameters(), lr=0.001, momentum=0.9)
    loss = fsdp_model(torch.ones(3, 3, 224, 224).cuda().float()).sum()  # reduce to a scalar so backward() works
    loss.backward()
    optim.step()
    with tempfile.TemporaryDirectory() as tmpdir:
        with fsdp_model.state_dict_type(
            fsdp_model,
            StateDictType.LOCAL_STATE_DICT,
            LocalStateDictConfig(offload_to_cpu=True),
            LocalOptimStateDictConfig(offload_to_cpu=True)
        ):
            # optim_state_dict = fsdp_model.full_optim_state_dict(fsdp_model, optim, rank0_only=True)
            optim_state_dict = fsdp_model.optim_state_dict(fsdp_model, optim)
            model_state_dict = fsdp_model.state_dict()
            torch.save(dict(model=model_state_dict, optim=optim_state_dict), osp.join(tmpdir, f'ckpt_{rank}.pth'))

            # all values should have been offloaded to CPU
            for p in model_state_dict.values():
                assert p.device == torch.device('cpu')
            # check that the saved checkpoint can be loaded back in the same (local) format
            ckpt = torch.load(osp.join(tmpdir, f'ckpt_{rank}.pth'))
            fsdp_model.load_state_dict(ckpt['model'])
            optim.load_state_dict(ckpt['optim'])



if __name__ == '__main__':
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank())
    # test_save_full_state_dict()
    test_save_sharded_state_dict()
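
The script has to be launched with a distributed launcher, e.g. torchrun --nproc_per_node=2 on a machine with at least two GPUs, so that dist.init_process_group can read the rank and world size from the environment.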