Hi,
When I try to call copy.deepcopy on a torch.distributed.fsdp.FullyShardedDataParallel model wrapper, the following error is raised:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/ema_pytorch/ema_pytorch.py", line 59, in __init__
    self.ema_model = copy.deepcopy(model)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parameter.py", line 55, in __deepcopy__
    result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/fsdp/flatten_params_wrapper.py", line 112, in __new__
    raise ValueError("An non-empty list or tuple argument is needed")
ValueError: An non-empty list or tuple argument is needed
When a parameter is deep copied, Parameter.__deepcopy__ calls fsdp.FlatParameter.__new__ and passes it self.data.clone() (which is a Tensor), but fsdp.FlatParameter.__new__ expects a non-empty list or tuple as its argument. Does this mean Parameter.__deepcopy__ is not implemented correctly for copying fsdp.FlatParameter? Or should torch.distributed.fsdp.FullyShardedDataParallel not be deep copied at all?
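The failure mode can be illustrated without torch by a stripped-down sketch. The Parameter and FlatParameter classes below are hypothetical stand-ins (not the real implementations); they only mimic the relevant shape: a __deepcopy__ that reconstructs via type(self)(data_copy, requires_grad), and a subclass whose __new__ requires a non-empty list or tuple.

```python
import copy

class Parameter:
    def __init__(self, data, requires_grad=True):
        self.data = data
        self.requires_grad = requires_grad

    def __deepcopy__(self, memo):
        # Mirrors the shape of torch.nn.Parameter.__deepcopy__:
        #   result = type(self)(self.data.clone(...), self.requires_grad)
        # type(self) is the *subclass* when called on a subclass instance.
        return type(self)(copy.deepcopy(self.data), self.requires_grad)

class FlatParameter(Parameter):
    def __new__(cls, params, requires_grad=True):
        # Mimics flatten_params_wrapper.FlatParameter.__new__'s check
        if not isinstance(params, (list, tuple)) or len(params) == 0:
            raise ValueError("An non-empty list or tuple argument is needed")
        return super().__new__(cls)

    def __init__(self, params, requires_grad=True):
        # After construction the "flattened" data is a single object,
        # no longer a list of parameters.
        super().__init__(data=float(sum(params)), requires_grad=requires_grad)

flat = FlatParameter([1.0, 2.0])
err = None
try:
    # __deepcopy__ calls FlatParameter(3.0, True): a plain data object,
    # not a list, so __new__ rejects it.
    copy.deepcopy(flat)
except ValueError as e:
    err = str(e)
print(err)  # An non-empty list or tuple argument is needed
```

This suggests the mismatch is between Parameter.__deepcopy__ reusing the subclass constructor and FlatParameter defining a constructor signature incompatible with it.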