I’m currently running an experiment with DistributedDataParallel (DDP) using batch normalization that is *not* synchronized across processes. I have two questions regarding some issues:
- Since I am not synchronizing the batch norm, each replica keeps its own running mean and running variance, so the running stats differ across GPUs. However, when I evaluate the model on different GPUs, the results come out identical. Can somebody tell me how this could be happening? (To double-check, I also compare the stats across ranks; see the sketch after the evaluation code below.)
Here is my code for evaluation:
```python
import torch
import torch.nn as nn

def evaluate(test_loader, model, rank, epoch):
    model.eval()
    # Print the running mean of the first BatchNorm2d layer on this rank.
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            print(module.running_mean)
            break
    accuracy = 0.
    cnt = 0.
    with torch.no_grad():
        first = True
        for data in test_loader:
            inputs, labels = data[0].to(rank), data[1].to(rank)
            if epoch == 0 and first:
                print(f"Val Batch Size: {inputs.shape[0]}")
                first = False
            preds = model(inputs)
            accuracy += (torch.argmax(preds, 1) == labels).sum().item()
            cnt += len(labels)
    accuracy *= 100 / cnt
    return accuracy
```
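To double-check that the running statistics really diverge across processes, I compare the first BatchNorm2d layer's `running_mean` across all ranks before evaluating. This is a minimal sketch: `check_bn_stats` is just a helper name I made up, and it assumes the default process group is already initialized:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

def check_bn_stats(model, rank, world_size):
    """Gather the first BatchNorm2d running_mean from every rank and compare."""
    local_mean = None
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            local_mean = module.running_mean.clone()
            break
    # Collect every rank's running_mean into one list (ordered by rank).
    gathered = [torch.zeros_like(local_mean) for _ in range(world_size)]
    dist.all_gather(gathered, local_mean)
    if rank == 0:
        for r, mean in enumerate(gathered):
            diff = (mean - gathered[0]).abs().max().item()
            print(f"rank {r}: max |running_mean - rank 0| = {diff:.6f}")
```

One thing I noticed but have not verified: DistributedDataParallel defaults to `broadcast_buffers=True`, and the traceback below goes through `_sync_params`, so perhaps the running stats are being re-broadcast from rank 0 on every forward call, which would explain the identical results.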
- Since evaluation on each device yields the same result, I tried to run evaluation on only a single process, but I got the following error:
```
Traceback (most recent call last):
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 119, in <module>
    mp.spawn(main, args=(world_size, args), nprocs=world_size, join=True)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 101, in main
    accuracy = evaluate(test_loader, model, rank)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 32, in evaluate
    preds = model(inputs)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 791, in forward
    self._sync_params()
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1318, in _sync_params
    self._distributed_broadcast_coalesced(
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1278, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer
```
But I have no clue why this is happening…
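For reference, the single-process evaluation I attempted looks roughly like this. This is a sketch, assuming `model` is the DistributedDataParallel wrapper; calling `model.module` directly and adding the barrier are workarounds I am experimenting with, not something I know to be the intended approach:

```python
import torch.distributed as dist

# Evaluate on rank 0 only, calling the wrapped module directly so that the
# DDP forward pass (and its parameter/buffer broadcast) is never triggered.
if rank == 0:
    accuracy = evaluate(test_loader, model.module, rank, epoch)
    print(f"Accuracy: {accuracy:.2f}%")
dist.barrier()  # keep the other ranks alive until rank 0 finishes
```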