How does batch normalization work with multiple GPUs?

I am going to use 2 GPUs for data-parallel training, and my model contains batch normalization. I am wondering how PyTorch handles BN with 2 GPUs. Does each GPU estimate the mean and variance separately? And suppose that at test time I will only use one GPU; which mean and variance will PyTorch use then?


Have you found an answer to this yet?

According to the PyTorch documentation, batch norm computes its statistics over the mini-batch it actually sees, which with DataParallel means per GPU: each replica estimates mean and variance from its own slice of the batch. The running statistics that persist for evaluation are, as far as I understand it, the ones accumulated on the source module's device (GPU 0), so at test time a single GPU simply uses those saved buffers.
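Here is a minimal sketch of that default behavior (module names and shapes are just illustrative, not from the original question):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),   # stats computed per replica's sub-batch
    nn.ReLU(),
)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # splits each batch across the GPUs
model = model.cuda()

x = torch.randn(32, 3, 224, 224).cuda()  # with 2 GPUs, each replica sees 16 samples
out = model(x)

# At test time, model.eval() makes BN use the saved running_mean / running_var
# buffers, so inference works fine on a single GPU (or even the CPU).
model.eval()
```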

See https://pytorch.org/docs/stable/nn.html#torch.nn.SyncBatchNorm: with DistributedDataParallel and SyncBatchNorm, the BN statistics can be synchronized across multiple GPUs, so the mean and variance are computed over the whole global batch rather than per device.
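A short sketch of that route, assuming a distributed process group has already been initialized (e.g. via torch.distributed.init_process_group; the local_rank value here is illustrative and would normally come from the launcher):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Replace every BatchNorm*d layer with SyncBatchNorm so mean/variance
# are reduced across all processes/GPUs during training.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

local_rank = 0  # normally read from the environment set by torchrun
model = model.cuda(local_rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```

Note that SyncBatchNorm only takes effect under DistributedDataParallel; it does not work with plain DataParallel.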