DDP, Batch Normalization, and Evaluation

I’m currently running experiments with DistributedDataParallel (DDP) and batch normalization (not synchronized). I have two questions about some issues I ran into:

  1. Since I am not synchronizing batch norm, each replica should keep its own running means and running variances. However, when I evaluate the model on different GPUs, the results are identical. Can somebody tell me how this could be happening?

Here is my code for evaluation:

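import torch
import torch.nn as nn


def evaluate(test_loader, model, rank, epoch):
    model.eval()  # BN layers now normalize with their running statistics
    # Print the first BN layer's running mean so it can be compared across ranks
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            print(f"[rank {rank}] running_mean: {module.running_mean}")
            break
    accuracy = 0.
    cnt = 0.

    with torch.no_grad():
        first = True
        for data in test_loader:
            inputs, labels = data[0].to(rank), data[1].to(rank)
            if epoch == 0 and first:
                print(f"Val Batch Size: {inputs.shape[0]}")
                first = False
            preds = model(inputs)

            # Count correct top-1 predictions
            accuracy += (torch.argmax(preds, 1) == labels).sum().item()
            cnt += len(labels)
    accuracy *= 100 / cnt  # convert to a percentage
    return accuracy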
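For reference, the expectation behind question 1 can be checked in isolation. Below is a minimal CPU-only sketch with no DDP involved: two deep copies of one nn.BatchNorm2d layer stand in for two replicas whose buffers are not synchronized (an illustrative assumption, not the actual training setup), each copy sees differently distributed data in train mode, and both are then compared in eval mode, where the running statistics are used.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Two copies of one BN layer stand in for two replicas whose buffers are
# NOT synchronized (illustrative assumption; no DDP involved here).
bn_a = nn.BatchNorm2d(3)
bn_b = copy.deepcopy(bn_a)

# Train mode: each forward updates running_mean / running_var from its batch.
bn_a.train()
bn_b.train()
bn_a(torch.randn(8, 3, 4, 4))        # roughly zero-mean data
bn_b(torch.randn(8, 3, 4, 4) + 5.0)  # same shape, shifted by +5

print("replica A running_mean:", bn_a.running_mean)
print("replica B running_mean:", bn_b.running_mean)

# Eval mode normalizes with the (now different) running statistics,
# so the same input produces different outputs.
bn_a.eval()
bn_b.eval()
x = torch.randn(2, 3, 4, 4)
with torch.no_grad():
    same_output = torch.allclose(bn_a(x), bn_b(x))
print("identical eval outputs:", same_output)
```

With the default momentum of 0.1, one training step moves replica B's running mean to roughly 0.5 per channel while replica A's stays near 0, so truly independent buffers should give visibly different eval results.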
  2. Since evaluation on each device yields the same result, I tried to evaluate on only one of the models, but got the following error:
Traceback (most recent call last):
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 119, in <module>
    mp.spawn(main, args=(world_size, args), nprocs=world_size, join=True)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 101, in main
    accuracy = evaluate(test_loader, model, rank)
  File "/home/lthilnklover/sam_torch/parallel_test.py", line 32, in evaluate
    preds = model(inputs)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 791, in forward
    self._sync_params()
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1318, in _sync_params
    self._distributed_broadcast_coalesced(
  File "/home/lthilnklover/.conda/envs/torch_1.8/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1278, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer

But I have no clue why this is happening…
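The traceback shows the failure inside DistributedDataParallel.forward, which calls _sync_params and then a coalesced broadcast. With DDP's default broadcast_buffers=True, a DDP forward pass can broadcast module buffers (including BN running statistics) from rank 0, and that broadcast is a collective every rank must join; if only one rank calls forward, the peers are not there to participate (this may also be relevant to question 1). Below is a minimal sketch, not a definitive fix, of evaluating on rank 0 only. Assumptions: the helper name evaluate_rank0_only and its separate device argument are hypothetical; the inner network is reached through DDP's model.module attribute so the forward pass bypasses DDP's broadcast; the other ranks wait at dist.barrier() instead of racing ahead.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def evaluate_rank0_only(model, test_loader, rank, device):
    """Evaluate on rank 0 only; other ranks wait at a barrier.

    Hypothetical helper: `device` is split out from `rank` so the sketch
    also runs on CPU without a process group.
    """
    # Unwrap DDP: calling the inner module skips DDP.forward, so no buffer
    # broadcast (_sync_params) is issued that other ranks would have to join.
    net = model.module if isinstance(model, DDP) else model

    accuracy = None
    if rank == 0:
        was_training = net.training
        net.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = net(inputs)
                correct += (preds.argmax(dim=1) == labels).sum().item()
                total += labels.numel()
        accuracy = 100.0 * correct / total
        net.train(was_training)  # restore the previous mode

    # Keep ranks in lockstep so no rank moves on (or exits) while rank 0 evaluates.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    return accuracy
```

Called on every rank as evaluate_rank0_only(model, test_loader, rank, device=rank), it returns the accuracy on rank 0 and None elsewhere. An alternative, again an assumption about the setup rather than a confirmed fix, is to construct the wrapper with DistributedDataParallel(..., broadcast_buffers=False); note that this also stops DDP from overwriting each replica's BN buffers with rank 0's copies during training.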