Hi,
I’m having trouble with spectral normalization when training a distributed model.
The code works fine on a single GPU, but as soon as I move to multiple GPUs, the following error is raised:
  File "/usagers/clpla/NNTools/nntools/experiment/experiment.py", line 345, in start
    mp.spawn(self._start_process, nprocs=self.world_size, join=True)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 179, in start_processes
    process.start()
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
  File "/usagers/clpla/.conda/envs/torch18/lib/python3.8/site-packages/torch/nn/utils/parametrize.py", line 285, in getstate
    raise RuntimeError(
RuntimeError: Serialization of parametrized modules is only supported through state_dict(). See:
https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
Exception ignored in: <function Connection.__del__ at 0x7fea8bff7ca0>
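The frames above suggest the failure happens while the child process is being created, not during training: with the "spawn" start method, mp.spawn pickles the target function and its arguments (the reduction.dump call). Since the target is a bound method (self._start_process), pickling it also pickles self, presumably including the model stored on it, and parametrized modules refuse to be pickled. A minimal sketch reproducing the same RuntimeError without GPUs or multiprocessing, assuming plain pickling is the trigger:

```python
import pickle

import torch.nn as nn

# A single layer using the parametrization-based spectral norm, as in the post.
layer = nn.utils.parametrizations.spectral_norm(nn.Linear(4, 4))

try:
    # This is the same operation mp.spawn performs on its target and args
    # when using the "spawn" start method.
    pickle.dumps(layer)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```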
I simply use PyTorch’s implementation of spectral normalization (torch.nn.utils.parametrizations.spectral_norm):
import torch.nn as nn

model = my_model()
for m in model.modules():
    # Apply spectral normalization to every convolutional and linear layer
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.utils.parametrizations.spectral_norm(m)
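One possible restructuring to try, sketched under the assumption that the model can be built (or at least parametrized) inside the worker: run the spectral-norm loop within the function that mp.spawn executes, so only plain, picklable arguments cross the process boundary. The names build_model, add_spectral_norm, start_process, and launch are hypothetical; build_model stands in for my_model(), and the distributed setup is left as a comment.

```python
import torch.multiprocessing as mp
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for `my_model()` (expects 3x8x8 inputs).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 6 * 6, 10),
    )


def add_spectral_norm(model: nn.Module) -> nn.Module:
    # Same loop as in the post; the target layers are collected first so the
    # module tree is not modified while it is being traversed.
    layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    for layer in layers:
        nn.utils.parametrizations.spectral_norm(layer)
    return model


def start_process(rank: int, world_size: int) -> None:
    # Runs inside each spawned worker. Only the plain ints `rank` and
    # `world_size` are pickled; the parametrized model is created here.
    model = add_spectral_norm(build_model())
    # ... init_process_group(rank=rank, world_size=world_size), move `model`
    # to device `rank`, wrap it in DistributedDataParallel, train ...


def launch(world_size: int) -> None:
    # Call from under `if __name__ == "__main__":` in a real script.
    mp.spawn(start_process, args=(world_size,), nprocs=world_size, join=True)
```

In the setup from the traceback, the analogous change would be to keep the model on the experiment object without spectral norm and run the loop at the start of _start_process. If weights need to move between processes, the error message points to state_dict() as the supported route.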
Any idea how to fix this issue? I’m using PyTorch 1.10.2.
Thanks!