Implementing BatchNorm in PyTorch: problem with updating self.running_mean and self.running_var

I’m trying to implement batch normalization in PyTorch and apply it to a VGG16 network. Here’s my BatchNorm implementation:

import torch
import torch.nn as nn
from torch.distributions import uniform

class BatchNorm(nn.Module):
    def __init__(self, input, mode, momentum=0.9, epsilon=1e-05):
        '''
        input: number of channels of the expected 4D input (mini_batch_size, # channels, h, w)
        momentum: momentum for exponential average
        '''
        super(BatchNorm, self).__init__()
        #self.run_mode = run_mode
        #self.input_shape = input.shape
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.momentum = momentum
        self.run_mode = 0 # 0: training, 1: testing
        self.insize = input
        self.epsilon = epsilon

        # initialize weight(gamma), bias(beta), running mean and variance
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.weight = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize)) ## TODO
        self.bias = nn.Parameter(torch.zeros(self.insize)) ## TODO
        self.running_mean = torch.zeros(self.insize)
        self.running_var = torch.ones(self.insize)


    # def set_runmode(self, run_mode):
    #     self.run_mode = run_mode

    def forward(self, input, mode):
        if mode == 0:
            mean = input.mean([0,2,3]) # per-channel mean over batch and spatial dims
            var = input.var([0,2,3])   # per-channel variance (unbiased by default)
            # update running mean and var
            running_mean_current = self.momentum * self.running_mean
            running_mean_current = running_mean_current.to(self.device)
            self.running_mean = running_mean_current + (1.0-self.momentum) * mean
            running_var_current = self.momentum * self.running_var
            running_var_current = running_var_current.to(self.device)
            self.running_var = running_var_current + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*var)

            # change shape
            current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
            current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
            current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
            current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)
            # get output
            y = current_weight * (input - current_mean) / (
                        current_var + self.epsilon).sqrt() + current_bias

        else:
            mean = self.running_mean
            var = self.running_var

            # change shape
            current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
            current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
            current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
            current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)
            # get output
            y = current_weight.data.cpu() * (input.data.cpu() - current_mean) / (
                        current_var + self.epsilon).sqrt() + current_bias.data.cpu()
            y = y.cuda()

        return y

and here is how the custom BatchNorm is used in the VGG16 network:

import torch.nn as nn
import torch.nn.functional as F

import batchnorm  # the module containing the BatchNorm class above

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1_1 = batchnorm.BatchNorm(64, mode=0)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn1_2 = batchnorm.BatchNorm(64, mode=0)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2_1 = batchnorm.BatchNorm(128, mode=0)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn2_2 = batchnorm.BatchNorm(128, mode=0)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3_1 = batchnorm.BatchNorm(256, mode=0)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3_2 = batchnorm.BatchNorm(256, mode=0)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn3_3 = batchnorm.BatchNorm(256, mode=0)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4_1 = batchnorm.BatchNorm(512, mode=0)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4_2 = batchnorm.BatchNorm(512, mode=0)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4_3 = batchnorm.BatchNorm(512, mode=0)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5_1 = batchnorm.BatchNorm(512, mode=0)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5_2 = batchnorm.BatchNorm(512, mode=0)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn5_3 = batchnorm.BatchNorm(512, mode=0)

        self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = nn.Linear(512, 10)

    def forward(self, x, mode):
        out = F.relu(self.bn1_1(self.conv1_1(x), mode))
        out = self.pool(F.relu(self.bn1_2(self.conv1_2(out), mode)))
        out = F.relu(self.bn2_1(self.conv2_1(out), mode))
        out = self.pool(F.relu(self.bn2_2(self.conv2_2(out), mode)))
        out = F.relu(self.bn3_1(self.conv3_1(out), mode))
        out = F.relu(self.bn3_2(self.conv3_2(out), mode))
        out = self.pool(F.relu(self.bn3_3(self.conv3_3(out), mode)))
        out = F.relu(self.bn4_1(self.conv4_1(out), mode))
        out = F.relu(self.bn4_2(self.conv4_2(out), mode))
        out = self.pool(F.relu(self.bn4_3(self.conv4_3(out), mode)))
        out = F.relu(self.bn5_1(self.conv5_1(out), mode))
        out = F.relu(self.bn5_2(self.conv5_2(out), mode))
        out = self.avgpool(self.pool(F.relu(self.bn5_3(self.conv5_3(out), mode))))
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

However, I found that as the network trains, running_mean and running_var remain at their initial values (0 and 1, respectively). I can’t figure out which part I’m missing. Any help would be appreciated!


I additionally found that there is no problem when using a single GPU, but in the situation above I’m using DataParallel with 2 GPUs. According to the DataParallel tutorial (https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html), half of the inputs go to cuda:0 and the other half to cuda:1.
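
For reference, here is a minimal toy sketch (my own, just to illustrate the splitting; it assumes two visible GPUs) of how DataParallel scatters the batch in forward:

import torch
import torch.nn as nn

class Identity(nn.Module):
    def forward(self, x):
        # each replica sees its own chunk of the batch on its own device
        print(x.device, x.shape)
        return x

model = nn.DataParallel(Identity().cuda())
out = model(torch.randn(64, 3, 32, 32))  # prints cuda:0 / cuda:1, 32 samples each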

How can I adapt my BatchNorm implementation to work with DataParallel?

Could you have a look at my manual implementation and compare both approaches?


Thank you for the quick reply. I’ll take a look at it and get back! Quick question: I’ve come across the use of self.register_buffer() on self.running_mean and self.running_var. Would that change any operation regarding them?

register_buffer adds tensors that do not require gradients to the module’s internal state_dict, so that they can be saved and restored.
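
As a small illustration (a toy module, not tied to your BatchNorm), a registered buffer shows up in the state_dict and moves together with the module:

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super(WithBuffer, self).__init__()
        # no gradient is tracked for this tensor, but it is part of the module's state
        self.register_buffer('running_mean', torch.zeros(4))

m = WithBuffer()
print(m.state_dict())             # contains 'running_mean'
if torch.cuda.is_available():
    m = m.cuda()
    print(m.running_mean.device)  # the buffer was moved along with the module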

I’ve reviewed your implementation of BatchNorm2d. However, since I’m trying to implement it from scratch (not inheriting anything from the built-in nn.BatchNorm), there are still a few things unclear to me.

I slightly modified BatchNorm as follows.

class BatchNorm(nn.Module):
    def __init__(self, input, mode, momentum=0.9, epsilon=1e-05):
        '''
        input: number of channels of the expected 4D input (mini_batch_size, # channels, h, w)
        momentum: momentum for exponential average
        '''
        super(BatchNorm, self).__init__()
        self.momentum = momentum
        self.run_mode = 0 # 0: training, 1: testing
        self.insize = input
        self.epsilon = epsilon

        # initialize weight(gamma), bias(beta), running mean and variance
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.weight = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.bias = nn.Parameter(torch.zeros(self.insize))
        self.register_buffer('running_mean', torch.zeros(self.insize)) # registering as a buffer fixes the CPU/CUDA device mismatch
        self.register_buffer('running_var', torch.ones(self.insize))

        # self.running_mean = torch.zeros(self.insize) # torch.zeros(self.insize)
        # self.running_var = torch.ones(self.insize)

        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)


    def forward(self, input, mode):
        if mode == 0:
            mean = input.mean([0,2,3]) # per-channel mean over batch and spatial dims
            var = input.var([0,2,3])
            self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean # .to(input.device)
            self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*var)

        else:
            mean = self.running_mean
            var = self.running_var

        # change shape
        current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
        current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
        current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
        current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)

        # get output
        y = current_weight * (input - current_mean) / (current_var + self.epsilon).sqrt() + current_bias

        return y

In the training loop, which looks like this,

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs, mode=0)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

whenever I hit the line outputs = net(inputs, mode=0), I see the running mean and var get calculated and the weight and bias get updated. However, as soon as I return to train.py and hit loss = criterion(outputs, targets), the running mean and var are reset to 0 and 1 again.

P.S. I find this super weird, since I’ve checked that the updated running mean and var are kept correctly when I use only a single GPU. The issue happens when I use multiple GPUs with nn.DataParallel.

For multiple GPUs, the running estimates of the default device should be used.
Could you post an executable code snippet that reproduces this issue?

Here’s an executable code snippet that reproduces the issue I’m having.
When I print the running mean and variance during the forward() step, I see that my BatchNorm (bn1) somehow does not get updated within the network.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import uniform
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

class BatchNorm(nn.Module):
    def __init__(self, input, mode, momentum=0.9, epsilon=1e-05):
        super(BatchNorm, self).__init__()
        self.momentum = momentum
        self.run_mode = 0 # 0: training, 1: testing
        self.insize = input
        self.epsilon = epsilon

        # initialize weight(gamma), bias(beta), running mean and variance
        U = uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0]))
        self.weight = nn.Parameter(U.sample(torch.Size([self.insize])).view(self.insize))
        self.bias = nn.Parameter(torch.zeros(self.insize))
        self.register_buffer('running_mean', torch.zeros(self.insize)) # registering as a buffer fixes the CPU/CUDA device mismatch
        self.register_buffer('running_var', torch.ones(self.insize))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def forward(self, input, mode):
        if mode == 0:
            mean = input.mean([0,2,3]) # per-channel mean over batch and spatial dims
            var = input.var([0,2,3])
            self.running_mean = (self.momentum * self.running_mean) + (1.0-self.momentum) * mean # .to(input.device)
            self.running_var = (self.momentum * self.running_var) + (1.0-self.momentum) * (input.shape[0]/(input.shape[0]-1)*var)

        else:
            mean = self.running_mean
            var = self.running_var

        # change shape
        current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
        current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
        current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
        current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)

        # get output
        y = current_weight * (input - current_mean) / (current_var + self.epsilon).sqrt() + current_bias

        return y



class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = BatchNorm(64, mode=0)
        self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = nn.Linear(16384, 10)

    def forward(self, x, mode):
        out = self.avgpool(self.pool(F.relu(self.bn1(self.conv1(x), mode))))
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        print("======================================================")
        print("==> printing bn1 running mean from NET during forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_mean)
        print("==> printing bn1 running var from NET during forward")
        print(net.module.bn1.running_var)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_var)
        return out

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = net()
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs, mode=0)
        loss = criterion(outputs, targets)
        print("====================================================")
        print("==> printing bn1 running mean FROM net after forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running var FROM net after forward")
        print(net.module.bn1.running_var)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        break


for epoch in range(0, 1):
    train(epoch)

Here’s what this code prints.

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..

Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0128, -0.0358,  0.0290,  0.0318,  0.0084,  0.0128,  0.0154,  0.0134,
         0.0136,  0.0083, -0.0045,  0.0129, -0.0102, -0.0212,  0.0096, -0.0075,
        -0.0218, -0.0206,  0.0209,  0.0205,  0.0054,  0.0289,  0.0007,  0.0021,
         0.0038,  0.0060,  0.0103, -0.0062, -0.0202,  0.0034, -0.0381,  0.0033,
        -0.0023, -0.0251,  0.0124, -0.0383,  0.0060,  0.0007, -0.0519, -0.0023,
         0.0106, -0.0149,  0.0044,  0.0117,  0.0005,  0.0139, -0.0214, -0.0409,
         0.0115,  0.0143,  0.0020, -0.0367, -0.0468,  0.0178,  0.0090,  0.0306,
        -0.0371, -0.0076, -0.0028,  0.0218, -0.0059, -0.0186,  0.0113, -0.0305],
       device='cuda:0', grad_fn=<AddBackward0>)
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running var from SELF. during forward
tensor([0.9281, 1.1053, 1.0759, 0.9632, 0.9372, 0.9262, 1.0004, 0.9267, 0.9207,
        0.9355, 0.9205, 0.9140, 0.9843, 0.9189, 0.9344, 0.9172, 0.9390, 1.1078,
        1.1116, 0.9229, 0.9183, 0.9362, 0.9684, 0.9877, 0.9519, 0.9155, 0.9422,
        0.9362, 0.9389, 0.9236, 1.0129, 0.9349, 0.9155, 0.9697, 0.9733, 1.0286,
        0.9520, 0.9706, 1.1903, 0.9599, 0.9428, 0.9158, 0.9805, 0.9188, 0.9361,
        0.9651, 0.9629, 1.2728, 1.0130, 0.9128, 0.9790, 1.0832, 1.1244, 0.9504,
        0.9162, 0.9488, 0.9979, 0.9494, 1.0155, 0.9752, 0.9204, 0.9216, 0.9375,
        0.9471], device='cuda:0', grad_fn=<AddBackward0>)
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0133, -0.0348,  0.0268,  0.0328,  0.0073,  0.0127,  0.0156,  0.0131,
         0.0130,  0.0080, -0.0051,  0.0112, -0.0105, -0.0230,  0.0111, -0.0070,
        -0.0228, -0.0192,  0.0184,  0.0224,  0.0044,  0.0291,  0.0026,  0.0025,
         0.0044,  0.0050,  0.0078, -0.0052, -0.0192,  0.0052, -0.0397,  0.0066,
        -0.0038, -0.0250,  0.0128, -0.0389,  0.0060,  0.0026, -0.0508, -0.0017,
         0.0101, -0.0154,  0.0049,  0.0104, -0.0002,  0.0117, -0.0192, -0.0427,
         0.0111,  0.0154,  0.0009, -0.0371, -0.0472,  0.0195,  0.0097,  0.0306,
        -0.0365, -0.0059, -0.0013,  0.0216, -0.0092, -0.0190,  0.0125, -0.0320],
       device='cuda:1', grad_fn=<AddBackward0>)
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running var from SELF. during forward
tensor([0.9286, 1.1101, 1.0870, 0.9601, 0.9357, 0.9248, 1.0001, 0.9262, 0.9197,
        0.9331, 0.9196, 0.9129, 0.9832, 0.9175, 0.9312, 0.9172, 0.9359, 1.1148,
        1.1235, 0.9212, 0.9167, 0.9369, 0.9676, 0.9868, 0.9497, 0.9146, 0.9459,
        0.9333, 0.9410, 0.9214, 1.0089, 0.9348, 0.9154, 0.9720, 0.9733, 1.0262,
        0.9516, 0.9689, 1.2014, 0.9553, 0.9422, 0.9149, 0.9757, 0.9174, 0.9340,
        0.9708, 0.9680, 1.2622, 1.0139, 0.9120, 0.9817, 1.0828, 1.1253, 0.9478,
        0.9153, 0.9497, 1.0001, 0.9536, 1.0213, 0.9773, 0.9229, 0.9196, 0.9351,
        0.9416], device='cuda:1', grad_fn=<AddBackward0>)
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

The first two blocks are paired; they show the running mean and var during forward (since I’m using DataParallel with 2 GPUs, there are two pairs of output). Here I see that the running mean and var only get updated in self.bn1, but the update is not synced back to the network itself.


Could you please help me extend your code to a multi-GPU version (DataParallel)? I am facing the same problem as mentioned by @SeoHyeong.

I’m not sure why the running stats updates are not gathered to the default device, but using

self.running_mean.copy_(...)
# instead of
self.running_mean = (...)

seems to perform the updates properly.
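
Applied to the forward above, the training branch could then look roughly like this (just a sketch; the torch.no_grad() wrapper is my own addition, to keep the in-place buffer updates out of autograd):

def forward(self, input, mode):
    if mode == 0:
        mean = input.mean([0, 2, 3])  # per-channel batch mean
        var = input.var([0, 2, 3])    # per-channel batch variance
        with torch.no_grad():
            # write into the existing buffers in place instead of rebinding the attributes
            self.running_mean.copy_(self.momentum * self.running_mean
                                    + (1.0 - self.momentum) * mean)
            self.running_var.copy_(self.momentum * self.running_var
                                   + (1.0 - self.momentum)
                                   * (input.shape[0] / (input.shape[0] - 1) * var))
    else:
        mean = self.running_mean
        var = self.running_var

    current_mean = mean.view([1, self.insize, 1, 1]).expand_as(input)
    current_var = var.view([1, self.insize, 1, 1]).expand_as(input)
    current_weight = self.weight.view([1, self.insize, 1, 1]).expand_as(input)
    current_bias = self.bias.view([1, self.insize, 1, 1]).expand_as(input)
    return current_weight * (input - current_mean) / (current_var + self.epsilon).sqrt() + current_bias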

CC @SeoHyeong