Calling copy.deepcopy on FSDP raises an error

Hi,

When I tried to call copy.deepcopy on a torch.distributed.fsdp.FullyShardedDataParallel model wrapper, the following error was raised:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/ema_pytorch/ema_pytorch.py", line 59, in __init__
    self.ema_model = copy.deepcopy(model)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parameter.py", line 55, in __deepcopy__
    result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 112, in __new__
    raise ValueError("An non-empty list or tuple argument is needed")
ValueError: An non-empty list or tuple argument is needed

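For reference, here is a minimal sketch of how I hit this (the ema_pytorch EMA wrapper calls copy.deepcopy(model) internally); the single-process group setup below is only illustrative:

import copy
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process group just so FSDP can be constructed (illustrative setup).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = FSDP(nn.Linear(8, 8).cuda())  # parameters get flattened into a FlatParameter
ema_model = copy.deepcopy(model)      # raises the ValueError from FlatParameter.__new__
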
Parameter.__deepcopy__ calls fsdp.FlatParameter.__new__ when a FlatParameter is deep-copied. However, it passes self.data.clone() (a plain Tensor) to fsdp.FlatParameter.__new__, which expects a non-empty list or tuple of parameters.
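
In other words, judging from the traceback, the copy path boils down to roughly this:

# Simplified sketch of Parameter.__deepcopy__ (torch/nn/parameter.py, per the traceback above)
def __deepcopy__(self, memo):
    if id(self) in memo:
        return memo[id(self)]
    # type(self) is FlatParameter here, so this call dispatches to FlatParameter.__new__
    # with a plain Tensor, while __new__ expects a non-empty list/tuple of parameters.
    result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
    memo[id(self)] = result
    return result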

Does this mean Parameter.__deepcopy__ is not implemented to handle fsdp.FlatParameter correctly?
Or should torch.distributed.fsdp.FullyShardedDataParallel not be deep-copied at all?