Updating running_mean and running_var in a custom Batchnorm?

Hi,

I have been trying to implement a custom batch normalization function such that it can be extended to the Multi GPU version, in particular, the DataParallel module in Pytorch. The custom batchnorm works alright when using 1 GPU, but, when extended to 2 or more, the running mean and variance work in the forward function, but when it returns back from the network, the mean and variance are reinitialized to 0 and 1.

The torch.nn.DataParallel mentions in the warning section that " In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas which are destroyed after forward." But I am not really sure how to retain the mean and variance from the default device.

I have provided code with the result obtained during multi GPU training. This code utilizes the Batchnorm provided by @ptrblck here and extension of this code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import uniform
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter

class ptrblck_BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(ptrblck_BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input


class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = ptrblck_BatchNorm2d(64)
        print("==> printing bn1 mean when init")
        print(self.bn1.running_mean)
        print("==> printing bn1 when init")
        print(self.bn1.running_mean)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.avgpool(x)
        
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        print("======================================================")
        print("==> printing bn1 running mean from NET during forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_mean)
        print("==> printing bn1 running var from NET during forward")
        print(net.module.bn1.running_var)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_var)
        return x

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = net()
net = torch.nn.DataParallel(net).cuda()
print('Number of GPU {}'.format(torch.cuda.device_count()))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        print("====================================================")
        print("==> printing bn1 running mean FROM net after forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running var FROM net after forward")
        print(net.module.bn1.running_var)
        
        break


for epoch in range(0, 1):
    train(epoch)

Result:

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
==> printing bn1 mean when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
==> printing bn1 when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Number of GPU 2

Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([ 0.0053,  0.0010, -0.0077, -0.0290,  0.0241,  0.0258, -0.0048,  0.0151,
        -0.0133,  0.0080,  0.0197, -0.0042, -0.0188,  0.0233,  0.0310, -0.0230,
        -0.0133,  0.0222,  0.0119, -0.0042, -0.0220, -0.0169, -0.0342, -0.0025,
         0.0338, -0.0070,  0.0202,  0.0050,  0.0108,  0.0008,  0.0363,  0.0347,
        -0.0106,  0.0082,  0.0128,  0.0074,  0.0111, -0.0030, -0.0089,  0.0070,
        -0.0262, -0.0029,  0.0053, -0.0136, -0.0183,  0.0045, -0.0014, -0.0221,
         0.0132,  0.0064,  0.0388, -0.0220, -0.0008,  0.0400, -0.0187,  0.0397,
        -0.0131, -0.0176,  0.0035,  0.0055, -0.0270,  0.0066, -0.0149,  0.0135],
       device='cuda:0')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9073, 0.9220, 1.0947, 1.0687, 0.9624, 0.9252, 0.9131, 0.9066,
        0.9536, 0.9258, 0.9203, 1.0359, 0.9690, 1.1066, 1.0636, 0.9135, 0.9644,
        0.9373, 0.9846, 0.9696, 0.9454, 1.0459, 0.9245, 0.9778, 0.9709, 0.9352,
        0.9995, 0.9657, 0.9510, 1.0943, 1.0171, 0.9298, 1.0747, 0.9341, 0.9635,
        0.9978, 0.9303, 0.9261, 0.9137, 0.9569, 1.0066, 1.0463, 0.9955, 0.9621,
        0.9172, 0.9836, 0.9817, 0.9086, 0.9576, 1.0905, 0.9861, 0.9661, 1.1773,
        0.9345, 1.0904, 0.9133, 1.0660, 0.9164, 0.9058, 0.9446, 0.9225, 1.0914,
        0.9292], device='cuda:0')
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0020,  0.0002, -0.0103, -0.0426,  0.0386,  0.0311, -0.0059,  0.0151,
        -0.0140,  0.0145,  0.0218, -0.0029, -0.0281,  0.0284,  0.0449, -0.0329,
        -0.0107,  0.0278,  0.0135, -0.0123, -0.0260, -0.0214, -0.0423, -0.0035,
         0.0410, -0.0097,  0.0276,  0.0102,  0.0197, -0.0001,  0.0483,  0.0451,
        -0.0078,  0.0190,  0.0135, -0.0004,  0.0196, -0.0028, -0.0140,  0.0070,
        -0.0332, -0.0110,  0.0151, -0.0210, -0.0226,  0.0074, -0.0088, -0.0314,
         0.0125, -0.0003,  0.0505, -0.0312,  0.0086,  0.0544, -0.0245,  0.0528,
        -0.0086, -0.0290,  0.0063,  0.0042, -0.0339,  0.0061, -0.0277,  0.0092],
       device='cuda:1')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9072, 0.9211, 1.0999, 1.0714, 0.9610, 0.9209, 0.9125, 0.9063,
        0.9553, 0.9260, 0.9189, 1.0386, 0.9706, 1.1139, 1.0610, 0.9121, 0.9660,
        0.9366, 0.9886, 0.9683, 0.9454, 1.0511, 0.9227, 0.9792, 0.9704, 0.9330,
        0.9989, 0.9657, 0.9476, 1.1008, 1.0191, 0.9294, 1.0814, 0.9320, 0.9642,
        1.0006, 0.9287, 0.9254, 0.9128, 0.9559, 1.0100, 1.0521, 0.9972, 0.9621,
        0.9168, 0.9849, 0.9803, 0.9083, 0.9556, 1.0946, 0.9865, 0.9651, 1.1880,
        0.9330, 1.0959, 0.9116, 1.0706, 0.9149, 0.9057, 0.9450, 0.9215, 1.0972,
        0.9261], device='cuda:1')
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

How can I make sure that the running estimates of the default device be used? Currently, I am not working towards synchronized Batchnorm.

Thank you,
Best,
Shreyas Kamath

Answered in this topic.

Thank you so much! That works!

Also, I think it might be related to the way the following code is executed

net = torch.nn.DataParallel(net).cuda()

When I changed it to

net = net.cuda()
net = torch.nn.DataParallel(net)

The running mean and var did get updated.