@ptrblck I am currently running into another, closely related problem. After training with SWA, we need to update the batch norm statistics (reference). Since the structure of my dataset is different from what torch.optim.swa_utils.update_bn() expects, I am doing the following inside train() (recall that train() is the launcher I provide to mp.spawn()):
if rank == 0:
    # Refresh the BN running statistics of the averaged model on rank 0 only
    for batch in train_loader:
        image = batch['image'].cuda(rank, non_blocking=True)
        prediction = swa_model(image)
This leads to the following error:
Traceback (most recent call last):
File "starter.py", line 343, in <module>
nprocs=WORLD_SIZE, join=True
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
while not context.join():
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 118, in join
raise Exception(msg)
Exception:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
fn(i, *args)
File "/home/jupyter/Flood_Comp/starter.py", line 334, in train
prediction = swa_model(image)
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/optim/swa_utils.py", line 101, in forward
return self.module(*args, **kwargs)
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 610, in forward
self._sync_params()
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 1048, in _sync_params
authoritative_rank,
File "/home/jupyter/.local/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 979, in _distributed_broadcast_coalesced
self.process_group, tensors, buffer_size, authoritative_rank
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:575] Connection closed by peer [10.138.0.33]:26791
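For reference, what I am ultimately trying to do with the loop above is replicate what torch.optim.swa_utils.update_bn() does internally, just adapted to my dict-style batches. Roughly along these lines (just a sketch of my intent; update_bn_with_dict_batches is a placeholder name, and the loop shown earlier only covers the forward-pass part):

import torch

def update_bn_with_dict_batches(loader, model, rank):
    # Reset BN running statistics and remember the original momenta,
    # mirroring what torch.optim.swa_utils.update_bn does internally.
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta:
        module.momentum = None  # None -> cumulative moving average

    # Forward passes only; no gradients are needed to refresh the BN stats.
    with torch.no_grad():
        for batch in loader:
            image = batch['image'].cuda(rank, non_blocking=True)
            model(image)

    # Restore the original momenta and training mode.
    for module, momentum in momenta.items():
        module.momentum = momentum
    model.train(was_training)

The idea would be to call update_bn_with_dict_batches(train_loader, swa_model, rank) on rank 0 after training finishes.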
Is there anything I am missing here?