DDP and Spectral Normalization

I’m having trouble with spectral normalization when training a distributed model.
The code works fine on a single GPU, but as soon as I move to multiple GPUs, the following error is raised:

File "/usagers/clpla/NNTools/nntools/experiment/experiment.py", line 345, in start
    mp.spawn(self._start_process, nprocs=self.world_size, join=True)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 179, in start_processes
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/nn/utils/parametrize.py", line 285, in __getstate__
    raise RuntimeError(
RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
Exception ignored in: <function Connection.__del__ at 0x7fea8bff7ca0>

I simply use PyTorch’s implementation of spectral normalization:

model = my_model()
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.utils.parametrizations.spectral_norm(m)
Any idea on how to fix this issue? I’m using PyTorch 1.10.2.

It looks like you are initializing your module (and applying spectral_norm) first, and then passing it to mp.spawn, which tries to pickle it and fails. I’d suggest calling mp.spawn first and then initializing your model within each process. You can find simple examples of this here: Writing Distributed Applications with PyTorch — PyTorch Tutorials 1.10.1+cu102 documentation
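The suggested fix can be sketched as follows. This is a minimal sketch, not the poster’s actual code: build_model and train_worker are hypothetical names, and the gloo backend, address, and port are illustrative. The key point is that the model is constructed (and spectral norm applied) inside the spawned worker, so the parametrized module is never pickled by mp.spawn.

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def build_model() -> nn.Module:
    # Placeholder for the poster's my_model(); any nn.Module works here.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))


def train_worker(rank: int, world_size: int) -> None:
    # Each process sets up its own process group ...
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # ... and only then builds the model, so nothing parametrized
    # ever crosses a process boundary via pickling.
    model = build_model()
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.utils.parametrizations.spectral_norm(m)

    ddp_model = DDP(model)
    # ... training loop using ddp_model ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # mp.spawn only pickles the plain args tuple, not the model.
    mp.spawn(train_worker, args=(world_size,), nprocs=world_size, join=True)
```

Applying spectral norm after DDP construction can also cause trouble, so the sketch applies it before wrapping the model.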


Hi, thanks for your answer and sorry for the late feedback.
I was working out how to adapt my code to this specific case, since my pipeline wasn’t originally designed with this problem in mind.
Initializing the model after the call to mp.spawn did solve the problem, thanks!