Implementing Batchnorm in Pytorch. Problem with updating self.running_mean and self.running_var

register_buffer makes sure to add tensors, which do not require gradients, to the internal state_dict, so that they can be saved and restored.
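To make the distinction concrete, here is a minimal sketch (the RunningStats module is hypothetical and only for illustration). It shows that a registered buffer appears in the state_dict and in named_buffers(), but not among the trainable parameters:

```python
import torch
import torch.nn as nn

class RunningStats(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, num_features):
        super().__init__()
        # trainable parameter: returned by parameters(), receives gradients
        self.weight = nn.Parameter(torch.ones(num_features))
        # buffer: saved/restored via state_dict, moved by .to()/.cuda(), no gradient
        self.register_buffer('running_mean', torch.zeros(num_features))

m = RunningStats(4)
state_keys = list(m.state_dict().keys())
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]

print(state_keys)
print(param_names)
print(buffer_names)
```

Because the buffer is part of the module state, calling m.to(device) should move it together with the parameters, which is what avoids the CPU/CUDA mismatch mentioned in the code comments of this thread.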

I’ve reviewed your implementation of BatchNorm2d. However, since I’m trying to implement it from scratch (not inheriting anything from the built-in nn.BatchNorm), some things are still unclear to me.

I slightly modified BatchNorm as follows.

class BatchNorm(nn.Module):
    def __init__(self, input, mode, momentum=0.9, epsilon=1e-05):
        """
        input: assume 4D input (mini_batch_size, # channel, w, h)
        momentum: momentum for exponential average
        """
        super(BatchNorm, self).__init__()
        self.momentum = momentum
        self.run_mode = 0 # 0: training, 1: testing
        self.insize = input
        self.epsilon = epsilon

        # initialize weight(gamma), bias(beta), running mean and variance
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.weight = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.bias = nn.Parameter(torch.zeros(self.insize))
        self.register_buffer('running_mean', torch.zeros(self.insize)) # this solves cpu and cuda mismatch location issue
        self.register_buffer('running_var', torch.ones(self.insize))

        # self.running_mean = torch.zeros(self.insize) # torch.zeros(self.insize)
        # self.running_var = torch.ones(self.insize)


    def reset_parameters(self):
        pass  # body not shown in the post

    def forward(self, input, mode):
        if mode == 0:
            mean = input.mean([0,2,3]) # along channel axis
            var = input.var([0,2,3])
            self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean # .to(input.device)
            self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*var)

        else:
            # testing: normalize with the running estimates
            mean = self.running_mean
            var = self.running_var

        # change shape
        current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
        current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
        current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
        current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)

        # get output
        y = current_weight * (input - current_mean) / (current_var + self.epsilon).sqrt() + current_bias

        return y

In the training process, which goes like below,

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs, mode=0)
        loss = criterion(outputs, targets)

Whenever I hit the line outputs = net(inputs, mode=0), I see that the running mean and var get calculated and that weight and bias get updated. However, as soon as I return and hit loss = criterion(outputs, targets), the running mean and var are reset to 0 and 1.

PS: I find this very weird, since I’ve checked that the updated running mean and var are kept when I use only a single GPU. The issue happens only when I use multiple GPUs with nn.DataParallel.

For multiple GPUs, the running estimates of the default device should be used.
Could you post an executable code snippet, which reproduces this issue?

Here’s the executable code snippet that reproduces the issue I’m having.
As I print out the running mean and variance during the forward() step, I see that my BatchNorm (bn1) somehow does not get updated within my network.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import uniform
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

class BatchNorm(nn.Module):
    def __init__(self, input, mode, momentum=0.9, epsilon=1e-05):
        super(BatchNorm, self).__init__()
        self.momentum = momentum
        self.run_mode = 0 # 0: training, 1: testing
        self.insize = input
        self.epsilon = epsilon

        # initialize weight(gamma), bias(beta), running mean and variance
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.weight = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.bias = nn.Parameter(torch.zeros(self.insize))
        self.register_buffer('running_mean', torch.zeros(self.insize)) # this solves cpu and cuda mismatch location issue
        self.register_buffer('running_var', torch.ones(self.insize))

    def reset_parameters(self):
        pass  # body not shown in the post

    def forward(self, input, mode):
        if mode == 0:
            mean = input.mean([0,2,3]) # along channel axis
            var = input.var([0,2,3])
            self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean # .to(input.device)
            self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*var)

        else:
            # testing: normalize with the running estimates
            mean = self.running_mean
            var = self.running_var

        # change shape
        current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
        current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
        current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
        current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)

        # get output
        y = current_weight * (input - current_mean) / (current_var + self.epsilon).sqrt() + current_bias

        return y

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = BatchNorm(64, mode=0)
        self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = nn.Linear(16384, 10)

    def forward(self, x, mode):
        out = self.avgpool(self.pool(F.relu(self.bn1(self.conv1(x), mode))))
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
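        # value prints reconstructed from the labels and the output shown below;
        # the global net is the DataParallel wrapper at this point
        print("==> printing bn1 running mean from NET during forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_mean)
        print("==> printing bn1 running var from NET during forward")
        print(net.module.bn1.running_var)
        # the label says "mean", but this prints the running var
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_var)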
        return out

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = net()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs, mode=0)
        loss = criterion(outputs, targets)
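        # value prints reconstructed from the labels and the output shown below
        print("==> printing bn1 running mean FROM net after forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running var FROM net after forward")
        print(net.module.bn1.running_var)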


        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()


for epoch in range(0, 1):
    train(epoch)

Here’s what this code prints.

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..

Epoch: 0
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0128, -0.0358,  0.0290,  0.0318,  0.0084,  0.0128,  0.0154,  0.0134,
         0.0136,  0.0083, -0.0045,  0.0129, -0.0102, -0.0212,  0.0096, -0.0075,
        -0.0218, -0.0206,  0.0209,  0.0205,  0.0054,  0.0289,  0.0007,  0.0021,
         0.0038,  0.0060,  0.0103, -0.0062, -0.0202,  0.0034, -0.0381,  0.0033,
        -0.0023, -0.0251,  0.0124, -0.0383,  0.0060,  0.0007, -0.0519, -0.0023,
         0.0106, -0.0149,  0.0044,  0.0117,  0.0005,  0.0139, -0.0214, -0.0409,
         0.0115,  0.0143,  0.0020, -0.0367, -0.0468,  0.0178,  0.0090,  0.0306,
        -0.0371, -0.0076, -0.0028,  0.0218, -0.0059, -0.0186,  0.0113, -0.0305],
       device='cuda:0', grad_fn=<AddBackward0>)
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9281, 1.1053, 1.0759, 0.9632, 0.9372, 0.9262, 1.0004, 0.9267, 0.9207,
        0.9355, 0.9205, 0.9140, 0.9843, 0.9189, 0.9344, 0.9172, 0.9390, 1.1078,
        1.1116, 0.9229, 0.9183, 0.9362, 0.9684, 0.9877, 0.9519, 0.9155, 0.9422,
        0.9362, 0.9389, 0.9236, 1.0129, 0.9349, 0.9155, 0.9697, 0.9733, 1.0286,
        0.9520, 0.9706, 1.1903, 0.9599, 0.9428, 0.9158, 0.9805, 0.9188, 0.9361,
        0.9651, 0.9629, 1.2728, 1.0130, 0.9128, 0.9790, 1.0832, 1.1244, 0.9504,
        0.9162, 0.9488, 0.9979, 0.9494, 1.0155, 0.9752, 0.9204, 0.9216, 0.9375,
        0.9471], device='cuda:0', grad_fn=<AddBackward0>)
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0133, -0.0348,  0.0268,  0.0328,  0.0073,  0.0127,  0.0156,  0.0131,
         0.0130,  0.0080, -0.0051,  0.0112, -0.0105, -0.0230,  0.0111, -0.0070,
        -0.0228, -0.0192,  0.0184,  0.0224,  0.0044,  0.0291,  0.0026,  0.0025,
         0.0044,  0.0050,  0.0078, -0.0052, -0.0192,  0.0052, -0.0397,  0.0066,
        -0.0038, -0.0250,  0.0128, -0.0389,  0.0060,  0.0026, -0.0508, -0.0017,
         0.0101, -0.0154,  0.0049,  0.0104, -0.0002,  0.0117, -0.0192, -0.0427,
         0.0111,  0.0154,  0.0009, -0.0371, -0.0472,  0.0195,  0.0097,  0.0306,
        -0.0365, -0.0059, -0.0013,  0.0216, -0.0092, -0.0190,  0.0125, -0.0320],
       device='cuda:1', grad_fn=<AddBackward0>)
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9286, 1.1101, 1.0870, 0.9601, 0.9357, 0.9248, 1.0001, 0.9262, 0.9197,
        0.9331, 0.9196, 0.9129, 0.9832, 0.9175, 0.9312, 0.9172, 0.9359, 1.1148,
        1.1235, 0.9212, 0.9167, 0.9369, 0.9676, 0.9868, 0.9497, 0.9146, 0.9459,
        0.9333, 0.9410, 0.9214, 1.0089, 0.9348, 0.9154, 0.9720, 0.9733, 1.0262,
        0.9516, 0.9689, 1.2014, 0.9553, 0.9422, 0.9149, 0.9757, 0.9174, 0.9340,
        0.9708, 0.9680, 1.2622, 1.0139, 0.9120, 0.9817, 1.0828, 1.1253, 0.9478,
        0.9153, 0.9497, 1.0001, 0.9536, 1.0213, 0.9773, 0.9229, 0.9196, 0.9351,
        0.9416], device='cuda:1', grad_fn=<AddBackward0>)
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

The first two blocks are paired: they show the running mean and var during forward (since I’m using DataParallel with 2 GPUs, there are two pairs of output). Here, I see that the running mean and var only get updated in self.bn1, but this update is not synced to the network itself.

Could you please help me extend your code to a multi-GPU version (DataParallel)? I am facing the same problem as mentioned by @SeoHyeong.

I’m not sure why the running stats updates are not gathered to the default device, but using an in-place update

self.running_mean.copy_(...)
# instead of
self.running_mean = (...)

seems to perform the updates properly.
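A possible explanation, based on the warning in the nn.DataParallel documentation: the module is replicated in each forward pass and the replicas are discarded afterwards, but the replica on device[0] shares storage for its parameters and buffers with the base module. An in-place update on that replica would therefore survive, while rebinding the attribute creates a new tensor that is lost with the replica. The sketch below imitates this on the CPU by holding on to the original buffer tensor; RebindStats and InplaceStats are hypothetical names, and this is an illustration of the difference, not the exact code from the thread.

```python
import torch
import torch.nn as nn

class RebindStats(nn.Module):  # hypothetical: mirrors the assignment used in the question
    def __init__(self, n, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(n))

    def forward(self, x):
        mean = x.mean([0, 2, 3])
        # rebinding: the attribute now points to a brand-new tensor
        self.running_mean = self.momentum * self.running_mean + (1.0 - self.momentum) * mean
        return x

class InplaceStats(RebindStats):  # hypothetical: in-place variant
    def forward(self, x):
        mean = x.mean([0, 2, 3])
        with torch.no_grad():
            # in-place: the existing storage is modified
            self.running_mean.mul_(self.momentum).add_((1.0 - self.momentum) * mean)
        return x

x = torch.ones(8, 3, 4, 4)

rebind = RebindStats(3)
shared_rebind = rebind.running_mean    # stands in for the storage kept by the base module
rebind(x)

inplace = InplaceStats(3)
shared_inplace = inplace.running_mean
inplace(x)

print(shared_rebind)    # the original tensor never sees the update
print(shared_inplace)   # the original tensor was updated in place
```

Wrapping the update in torch.no_grad() also keeps the running estimates out of the autograd graph, which would explain why the printed buffers in the question carry a grad_fn.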

CC @SeoHyeong


@ptrblck It seems that the pure Python implementation of BN consumes much more GPU memory during training (7 GB vs. 4 GB). To the best of my knowledge, the only tensors involved are mean, var and output. Besides, I have tried my best to use in-place operations. What is going wrong?

I cannot reproduce the increased memory usage with my manual implementation and comparing it via:

bn = MyBatchNorm2d(3, affine=True).cuda()
# bn = nn.BatchNorm2d(3, affine=True).cuda() # switch to this in another run
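for _ in range(10):
    x = torch.randn(16, 3, 224, 224, device='cuda')
    out = bn(x)
    # print statement reconstructed: currently allocated GPU memory in MB
    print(torch.cuda.memory_allocated() / 1024**2)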

and both print a memory allocation of ~18MB.

Hi,
I am implementing torch.nn.functional.normalize from scratch and found its source code:

def normalize(input, p=2, dim=1, eps=1e-12, out=None):
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                normalize, (input,), input, p=p, dim=dim, eps=eps, out=out)
    if out is None:
        denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
        return input / denom
    else:
        denom = input.norm(p, dim, keepdim=True).clamp_min_(eps).expand_as(input)
        return torch.div(input, denom, out=out)

I found it a little bit hard to follow, so I have tried to build it as

def normalize(input):
    return input / np.sqrt(np.sum(input*input, axis=-1, keepdims=True))

Do you think this will work, so that I could use it the same way as the normalize function in PyTorch?

What’s your use case for reimplementing this method?
In your code snippet x is undefined. Also, passing tensors containing zeros would create an invalid output.

I think it would work. I compared the results of both, this time using torch functions:

def normalize(input):
    return torch.div(input, torch.sqrt(torch.sum(input*input, axis=-1, keepdims=True)))

tensor([[-0.6721,  0.6924,  0.2624],
        [ 0.4446, -0.8870, -0.1247]], grad_fn=<DivBackward0>)

and with functional normalize


tensor([[-0.6721,  0.6924,  0.2624],
        [ 0.4446, -0.8870, -0.1247]], grad_fn=<DivBackward0>)

Is it correct?

Note that norm is probably more efficient than (input*input).sum().sqrt().
But yeah, if you don’t need it to work with __torch_function__s, don’t need to handle the out case, fix p to 2 and dim to -1, and want to ignore the eps (which is probably there for the benefit of vectors with vanishing norm), it is equivalent.
The explicit broadcasting via expand_as isn’t needed these days, as it’s implicit.

Best regards
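As a small illustration of why the eps matters, here is a sketch under the assumptions above (p fixed to 2, dim to -1). The helper names normalize_no_eps and normalize_with_eps are hypothetical. An all-zero row produces NaN without the clamp, while clamping the norm keeps the row at zero and matches F.normalize:

```python
import torch
import torch.nn.functional as F

def normalize_no_eps(x, p=2, dim=-1):
    # plain division by the norm: 0 / 0 for an all-zero vector
    return x / x.norm(p, dim, keepdim=True)

def normalize_with_eps(x, p=2, dim=-1, eps=1e-12):
    # clamp the norm from below, as the PyTorch source does with clamp_min(eps)
    return x / x.norm(p, dim, keepdim=True).clamp_min(eps)

x = torch.tensor([[3.0, 4.0],
                  [0.0, 0.0]])

a = normalize_no_eps(x)
b = normalize_with_eps(x)
ref = F.normalize(x, p=2, dim=-1)

print(a)    # second row is NaN
print(b)    # second row stays zero
print(ref)
```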


Yes thanks :grinning: and my final normalize function looks like this

def normalize(input, p=2, dim=-1):
    return input / input.norm(p,dim,keepdim=True)

neat and clean #fixedtypo


In your manual implementation, you have written that the weights and biases should be different (line:85). Can you please explain why that should be the case?

Does it have to be different for the same input and the same random seed?


I’m initializing the parameters with random values and making sure that they are indeed not equal (no seeding was used, and I also didn’t try to recreate the initialization of the PyTorch implementation):

# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

compare_bn(my_bn, bn)  # weight and bias should be different

After loading the state_dict I make sure that they are indeed equal.
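The check described here could be sketched roughly as follows. MyBatchNorm2d and compare_bn are not shown in this thread, so the sketch uses two nn.BatchNorm2d layers and a hypothetical helper states_equal as stand-ins; it only illustrates the "different before, equal after load_state_dict" idea.

```python
import torch
import torch.nn as nn

def states_equal(a, b):
    """Hypothetical stand-in for compare_bn: compare all state_dict entries."""
    sa, sb = a.state_dict(), b.state_dict()
    return all(torch.allclose(sa[k].float(), sb[k].float()) for k in sa)

torch.manual_seed(0)
bn_a = nn.BatchNorm2d(3, affine=True)
bn_b = nn.BatchNorm2d(3, affine=True)

# give bn_a random affine parameters so the two layers start out different
with torch.no_grad():
    bn_a.weight.uniform_()
    bn_a.bias.normal_()

before = states_equal(bn_a, bn_b)          # weight and bias should differ
bn_b.load_state_dict(bn_a.state_dict())    # copy parameters and buffers
after = states_equal(bn_a, bn_b)           # now they should match

x = torch.randn(4, 3, 8, 8)
same_output = torch.allclose(bn_a(x), bn_b(x))
print(before, after, same_output)
```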

I am using your implementation for my experiments. On an RTX 3090, using your implementation drops the training speed by about 30%. My own implementation, one that does not derive from nn.BatchNorm2d but rather simply from nn.Module, also had a similar drop in speed. I also tried other manual implementations with the same 30% drop in training performance. I then wrote a C++ extension using PyTorch’s C++ libs and ATen. It too had the same 30% drop in training speed. I have not tried writing a CUDA kernel for it though.

I am really fascinated by the default implementation. I have a couple of queries regarding this, if you can explain. What makes the default implementation so fast? How can I have a manual implementation that is as fast (or at least within 5-10% of the speed of the default implementation)?

The reason why I am asking this is actually very important. Here in the research community, we almost always write custom modules, like weight standardization, various other custom normalization schemes, or some other ops. In almost all cases, we see a significant drop in speed with our custom modules. Training a vanilla ResNet is so fast, and yet making just a small change to the architecture (by introducing a custom layer) drops the training speed significantly. Not only does it take longer for our models to train (thus slowing down research), but it also keeps the GPUs quite occupied (with high utilization and power draw). Forgive me for not knowing how the default modules of PyTorch work or how they are so fast.


It depends a bit on which implementation you mean by the “native” implementation.
E.g. since you are using the GPU, cudnn would be used, which provides fast algorithms for the batchnorm operations. To disable it, you could set torch.backends.cudnn.enabled = False and compare the speed again to your custom implementation to see whether you are inside the desired 5-10% performance drop window.
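A rough way to run such a comparison is sketched below. The time_module helper is hypothetical, the tensor shape is arbitrary, and the numbers depend entirely on hardware and input size; on a machine without a GPU the flag has no effect and both timings cover the same CPU path. Swap in your own module for a comparison against a custom implementation.

```python
import time
import torch
import torch.nn as nn

def time_module(module, x, iters=20):
    """Hypothetical helper: average seconds per forward/backward pass."""
    module(x).sum().backward()           # warm-up pass, not timed
    if x.is_cuda:
        torch.cuda.synchronize()         # wait for pending GPU work before timing
    start = time.perf_counter()
    for _ in range(iters):
        module.zero_grad()
        module(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(16, 3, 32, 32, device=device)
bn = nn.BatchNorm2d(3).to(device)

prev = torch.backends.cudnn.enabled
t_default = time_module(bn, x)

torch.backends.cudnn.enabled = False     # fall back to the non-cudnn kernels
t_no_cudnn = time_module(bn, x)
torch.backends.cudnn.enabled = prev      # restore the previous setting

print(f"cudnn enabled: {t_default:.6f}s, cudnn disabled: {t_no_cudnn:.6f}s")
```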

I tried the manual implementation of Batch Normalization. However, training fails to reach a good accuracy when I use it, and I am not sure what the cause may be.
I looked at the running mean and variance, and the tensors are all zeros.
Thanks for your help in advance.

I don’t know which manual implementation you are using, but my reference should update the running stats properly.

In my case, I use multiple nn.BatchNorm1d layers to construct several domain-specific BNs, and I designed the dataloader to allocate the data from different domains to different GPUs, such that different BNs get data from different GPUs, e.g. ‘gpu1: bn1, gpu2: bn2, …, gpu_n: bn_n’. After training, bn.weight and bn.bias seem to be updated properly, but only the running_mean (or running_var) of the first BN (on gpu:0) was updated; the mean and var of bn_2, …, bn_n were not updated at all. How can I solve this problem? Any advice? Thanks