How does BatchNorm keep track of running_mean?

Hi all,

I am trying to implement a custom version of batch norm from scratch (for a complex-valued network; I already have the other components and just need the batch norm).
For now, I want to write it in pure Python using Module + Function:
https://pytorch.org/docs/stable/notes/extending.html

To help me, I read the way batch norm is coded in:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html
and
https://pytorch.org/docs/stable/_modules/torch/nn/functional.html

The module is defined in torch.nn.modules.batchnorm, where running_mean and running_var are created as buffers and then passed, in forward, to nn.functional.batch_norm, which takes running_mean as an argument and itself calls the C++ batch norm function.

But running_mean is never updated in the Python code I read. How does the Module keep track of it if it is only evaluated in the C++ code? What would be the best way to do it in pure Python?

Thanks for your help,

SMP

The stats are updated in the ATen implementation.
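You can observe this from Python: when torch.nn.functional.batch_norm is called with training=True, it writes the new statistics into the running_mean and running_var tensors you pass in, in place. Because the module hands over its own buffer tensors, the buffers see the new values without any Python-side assignment. A minimal check, assuming a recent PyTorch release (the running variance is updated with the unbiased batch variance):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4) + 5.0  # (N, C, H, W) with a non-zero mean

# These tensors play the role of the module's registered buffers.
running_mean = torch.zeros(3)
running_var = torch.ones(3)

F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)

# The very same tensors now hold (1 - momentum) * old + momentum * batch stat.
expected_mean = 0.9 * torch.zeros(3) + 0.1 * x.mean(dim=(0, 2, 3))
expected_var = 0.9 * torch.ones(3) + 0.1 * x.var(dim=(0, 2, 3), unbiased=True)
print(torch.allclose(running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(running_var, expected_var, atol=1e-5))    # True
```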

If you would like to create a custom Python implementation, you could initialize the layer similar to the current BatchNorm implementation (i.e. register the affine parameters as nn.Parameter and the running estimates as buffers) and update the estimates manually.
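To see which kinds of state the stock layer keeps, you can list them. A custom layer would register the same split: learnable affine parameters as nn.Parameter, and the running estimates (which the optimizer must not touch) via register_buffer:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(4)

# Learnable affine parameters, returned by .parameters() and seen by the optimizer.
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']

# Running estimates: saved in state_dict and moved by .to(), but not optimized.
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
print(bn.running_mean.requires_grad)  # False
```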

Let me know, if you get stuck somewhere.


Hi ptrblck,

Thank you very much for your reply. I have never used the C++ API and I am not sure how the Python Module gets updated with the new value of current_mean, but for writing a custom Python implementation, I guess the only thing I need to know (for now) is that I have to do it manually. Thanks for the tip.

Now for the practical part: if I want to keep it as close to the original structure as possible, and also follow the example in the “Extending PyTorch” tutorial, I have to create an autograd Function that will do the calculation and an nn Module that will call this function in forward(). However, since the current_mean would be calculated inside the Function (as done in the original batch norm structure), how can the Module keep track of it? Should I simply calculate the current_mean inside the Module?

I am sorry if I am not very clear; this is my first time customizing PyTorch. I am a physicist attempting a conversion…

Thanks again for your help !

SMP

You could calculate the current mean and var inside the forward method of your custom batch norm layer.
Note that the backward pass can automatically be calculated if your forward method just uses PyTorch functions, so that you don’t necessarily need to write a custom autograd.Function.
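As an illustration of the second point: a normalization written only with tensor operations gets its gradient from autograd, and torch.autograd.gradcheck can compare that gradient against finite differences. The function normalize below is a hypothetical example; gradcheck expects double-precision inputs.

```python
import torch

def normalize(x, eps=1e-5):
    # Hypothetical helper: per-channel batch normalization over (N, H, W),
    # written with plain tensor ops only.
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(4, 3, 2, 2, dtype=torch.double, requires_grad=True)

# No custom backward was written; autograd derives it from the ops above.
ok = torch.autograd.gradcheck(normalize, (x,))
print(ok)  # True
```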

I’ve created a Python implementation of the nn.BatchNorm2d layer here, which might be a good starter for your experiments.
I haven’t timed this approach yet, but I assume it to be a lot slower than the PyTorch implementation.
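The linked implementation is not reproduced here, so the following is a separate minimal sketch of the same idea, not that code. The batch statistics are computed in forward, the buffers are updated in place under torch.no_grad(), and autograd handles the backward pass. It assumes momentum is a float (no cumulative average) and that the layer is always affine; the class name PyBatchNorm2d is hypothetical.

```python
import torch
import torch.nn as nn

class PyBatchNorm2d(nn.Module):
    """Hypothetical pure-Python BatchNorm2d sketch (float momentum, always affine)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3), unbiased=False)  # biased var normalizes the batch
            n = x.numel() / x.size(1)                   # elements per channel: N*H*W
            with torch.no_grad():
                # In-place updates keep the same buffer tensors and record no graph.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var * n / (n - 1))
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        return x_hat * self.weight[None, :, None, None] + self.bias[None, :, None, None]

torch.manual_seed(0)
x = torch.randn(16, 3, 8, 8) * 2 + 1
ref, mine = nn.BatchNorm2d(3), PyBatchNorm2d(3)
out_ref, out_mine = ref(x), mine(x)
print((out_ref - out_mine).abs().max().item() < 1e-5)                  # True
print(torch.allclose(ref.running_mean, mine.running_mean, atol=1e-6))  # True
print(torch.allclose(ref.running_var, mine.running_var, atol=1e-5))    # True
ref.eval(); mine.eval()
print(torch.allclose(ref(x), mine(x), atol=1e-5))                      # True
```

Updating the buffers in place (rather than reassigning them) mirrors what the stock layer's backend does, and the no_grad block keeps the running estimates out of the autograd graph.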


You are amazing, your example is exactly what I needed!

So if using PyTorch functions is enough for the backward to be calculated automatically, then I am not sure I understand the Linear/LinearFunction example in the docs, as it seems to use only torch functions too, right?

I agree that it may be significantly slower, but I am trying to go step by step. It seems that complex-valued tensors are not very relevant for most learning applications; that is at least my conclusion, seeing that their implementation is not a priority at all. However, for physics applications, they are of the utmost importance, as transformations are usually complex, even if the phase is often a hidden variable of the problem. So I am trying to get my hands dirty.

Thanks for your help,

SMP


You would have to write the backward method if you are using another library inside the forward method (e.g. numpy), or if you think the automatically generated backward might be slower than your custom one.
The example just shows the general approach of doing so.
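A case where a hand-written backward is actually required: the forward pass leaves PyTorch (here through NumPy), so autograd cannot record the operation. NumpySquare is a hypothetical example; gradcheck verifies the manual gradient.

```python
import numpy as np
import torch

class NumpySquare(torch.autograd.Function):
    """Hypothetical example: forward computes x**2 in NumPy, invisible to autograd."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = np.square(x.detach().cpu().numpy())  # autograd does not see this op
        return torch.from_numpy(y).to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x               # d(x^2)/dx = 2x, written by hand

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(NumpySquare.apply, (x,)))  # True
```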

Ok got it, it makes sense.

Thanks

Hi again,

I have a few remarks about your code:

Line 63, I guess it should be n = input.numel() / input.size(1)

Then for the calculation of the variance, var = input.var([0, 2, 3], unbiased=False) gives me:

TypeError: var(): argument 'dim' (position 1) must be int, not list

I have pytorch 1.1.0a0+ac206a9, maybe it is version related, however, using
var = input.transpose(1,0).contiguous().view(input.size()[1],-1).var(-1, unbiased=False)
did the trick for me.
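For reference, input.numel() / input.size(1) is the number of elements per channel (N*H*W), which is what the per-channel statistics average over. A small check that the reshaping workaround gives the same biased variance as reducing over several dimensions at once (current PyTorch releases accept a list of dims in var; per the post above, the poster's 1.1.0 pre-release build did not):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)                    # (N, C, H, W)

n = x.numel() / x.size(1)                      # elements per channel
print(n)                                       # 32.0, i.e. 2 * 4 * 4

# Workaround: move channels first, flatten the rest, reduce over the last dim.
var_flat = x.transpose(1, 0).contiguous().view(x.size(1), -1).var(-1, unbiased=False)
# Direct multi-dimension reduction.
var_multi = x.var(dim=[0, 2, 3], unbiased=False)
print(torch.allclose(var_flat, var_multi, atol=1e-6))  # True
```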

By running your test, the max difference between the parameters of the stock batchnorm and the custom one is about 1e-6, which I'd say is good enough for me. I have not quantified the time difference yet, but I ran a network with it, and it is significantly slower, which I do not mind right now.

I implemented the complex version following your example and this paper. It should work, but I am a bit worried about the speed now, as it is heavier than the real version. The ball is in my court now.

Thanks again,

SMP

I’m using 1.0.0.dev20190312 and am not sure why .var() isn’t taking multiple dimensions in your case, but good you found a workaround.

The difference of 1e-6 is most likely created by floating point precision, so you don’t have to worry about it.

Once your custom layer works fine, we could try to write a more performant extension, so that this layer isn’t a bottleneck in your model.

Hi again,

So I came up with some code to replace critical parts of NNs with complex functions, so that they work on two N,C,H,W tensors for 2D or two N,C tensors for 1D, in the form

xr,xi = complexLayer(xr,xi)
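To make that convention concrete, here is a hypothetical complex linear layer in the same (real, imaginary) form. It is not the poster's linked code. It uses (W_r + iW_i)(x_r + ix_i) = (W_r x_r - W_i x_i) + i(W_r x_i + W_i x_r), with two real nn.Linear layers holding W_r and W_i. Biases are disabled to keep the check simple. The cross-check uses the native complex dtype available in recent PyTorch releases.

```python
import torch
import torch.nn as nn

class ComplexLinearSketch(nn.Module):
    """Hypothetical layer: real and imaginary parts travel as two real tensors."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc_r = nn.Linear(in_features, out_features, bias=False)  # W_r
        self.fc_i = nn.Linear(in_features, out_features, bias=False)  # W_i

    def forward(self, x_r, x_i):
        # (W_r + i W_i)(x_r + i x_i) = (W_r x_r - W_i x_i) + i (W_r x_i + W_i x_r)
        return self.fc_r(x_r) - self.fc_i(x_i), self.fc_r(x_i) + self.fc_i(x_r)

torch.manual_seed(0)
layer = ComplexLinearSketch(4, 2)
xr, xi = torch.randn(3, 4), torch.randn(3, 4)
yr, yi = layer(xr, xi)

# Cross-check against PyTorch's native complex tensors.
W = torch.complex(layer.fc_r.weight, layer.fc_i.weight)
y = torch.complex(xr, xi) @ W.T
print(torch.allclose(yr, y.real, atol=1e-6), torch.allclose(yi, y.imag, atol=1e-6))
```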

The code is here

It works nicely and without time cost, except for the complex batch norm. It uses the 2x2 covariance matrix between the real and imaginary parts, so there are extra calculations performed in Python.
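The whitening in that batch norm relies on a closed form for the inverse square root of each channel's 2x2 covariance matrix V = [[Crr, Cri], [Cri, Cii]]. With s = sqrt(det V) and t = sqrt(Crr + Cii + 2s), V^(-1/2) = 1/(s t) [[Cii + s, -Cri], [-Cri, Crr + s]]. This follows from sqrt(V) = (V + sI)/t for a symmetric positive-definite 2x2 matrix. A quick numerical check of the formula, with a hypothetical helper name and uncentered second moments (enough for the check):

```python
import torch

def inv_sqrt_2x2(Crr, Cii, Cri):
    # Hypothetical helper: closed-form inverse square root of [[Crr, Cri], [Cri, Cii]].
    s = torch.sqrt(Crr * Cii - Cri ** 2)           # sqrt(det V)
    t = torch.sqrt(Crr + Cii + 2 * s)              # sqrt(trace V + 2 s)
    inv_st = 1.0 / (s * t)
    return (Cii + s) * inv_st, (Crr + s) * inv_st, -Cri * inv_st  # Rrr, Rii, Rri

torch.manual_seed(0)
xr = torch.randn(1000, dtype=torch.double)
xi = 0.5 * torch.randn(1000, dtype=torch.double) + 0.3 * xr  # correlated parts
Crr, Cii, Cri = (xr * xr).mean(), (xi * xi).mean(), (xr * xi).mean()

Rrr, Rii, Rri = inv_sqrt_2x2(Crr, Cii, Cri)
R = torch.stack([torch.stack([Rrr, Rri]), torch.stack([Rri, Rii])])
V = torch.stack([torch.stack([Crr, Cri]), torch.stack([Cri, Cii])])
# If R = V^(-1/2), then R V R is the identity.
print(torch.allclose(R @ V @ R, torch.eye(2, dtype=torch.double)))  # True
```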

I have three issues with that complex batch norm code:

  1. It is slow. I really do not want to go into C++ code, especially regarding point 3 below, so is there any easy and/or automatic way to optimize the code, Numba style? I will perform some quantitative tests when I have some time.

  2. There seems to be a memory leak: the memory usage of my GPU increases over time until it crashes. When I use the “naive” complex batch norm approach, where I do standard batch norm independently on the real and imaginary parts, I do not have memory issues; the memory usage is constant. Is there a need to del/free something inside the forward to prevent memory issues?

  3. More a remark than an issue: I did not see any significant improvement compared to the naive approach. This is also true in the paper I took it from, but I work on a physical system where quantities are complex, so I hope it will improve things. Still, I am not giving up so soon; I will continue testing.

Hi,

I am still struggling with those issues, does anybody have some advice about that matter?

Best,

SMP

  1. Not sure how much this will help, but you could try to optimize your code using torch.jit.trace and compare the performance.

  2. Does ComplexBatchNorm2d have this memory leak or another module? I could have a look at it and try to narrow down possible issues.

  3. Do you mean the complex modules do not improve the performance of your overall model compared to just the “real” input?
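Regarding point 1, a sketch of tracing and timing: torch.jit.trace records the tensor operations executed on example inputs, so Python control flow (such as `if self.training:`) is frozen to the branch taken during tracing. For that reason this sketch traces only the statistics computation, a hypothetical helper, rather than the whole module. Whether tracing gives a speed-up has to be measured on your own hardware.

```python
import timeit
import torch

def complex_stats(x_r, x_i):
    # Hypothetical helper: per-channel covariance entries over (N, H, W).
    x_r = x_r - x_r.mean(dim=[0, 2, 3], keepdim=True)
    x_i = x_i - x_i.mean(dim=[0, 2, 3], keepdim=True)
    Crr = x_r.pow(2).mean(dim=[0, 2, 3])
    Cii = x_i.pow(2).mean(dim=[0, 2, 3])
    Cri = (x_r * x_i).mean(dim=[0, 2, 3])
    return Crr, Cii, Cri

x_r, x_i = torch.randn(64, 20, 12, 12), torch.randn(64, 20, 12, 12)
traced = torch.jit.trace(complex_stats, (x_r, x_i))

# The traced version should reproduce the eager results.
for eager, jit in zip(complex_stats(x_r, x_i), traced(x_r, x_i)):
    assert torch.allclose(eager, jit)

print("eager :", timeit.timeit(lambda: complex_stats(x_r, x_i), number=100))
print("traced:", timeit.timeit(lambda: traced(x_r, x_i), number=100))
```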

Hi ptrblck,

Thank you for your kind reply.

  1. Thanks, I will look at it.

  2. Yes, it is ComplexBatchNorm2d (and ComplexBatchNorm1D) that has this memory leak. It looks like the memory is not freed correctly, but since Python usually takes care of this for me, I do not know how to deal with such an issue.

  3. Sorry, I was not very clear. No, complex NNs give me better results! Especially since I use those networks to learn from physical systems with internal complex variables. What does not show (obvious) significant improvements is ComplexBatchNorm() (the batch norm that takes into account the covariance matrix between the real and imaginary parts) compared to NaiveComplexBatchNorm() (which basically performs batch norm independently on the real and imaginary parts). But it is likely problem dependent; I need to perform further tests once my memory issue is fixed.
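For comparison, the naive variant from point 3 can be sketched with two stock layers. This is a hypothetical class, not the poster's code. It standardizes each part separately, so any correlation between the real and imaginary parts survives. The covariance-based version whitens that correlation away.

```python
import torch
import torch.nn as nn

class NaiveComplexBatchNorm2dSketch(nn.Module):
    """Hypothetical sketch: independent BatchNorm2d on real and imaginary parts."""

    def __init__(self, num_features):
        super().__init__()
        self.bn_r = nn.BatchNorm2d(num_features)
        self.bn_i = nn.BatchNorm2d(num_features)

    def forward(self, x_r, x_i):
        return self.bn_r(x_r), self.bn_i(x_i)

torch.manual_seed(0)
x_r = torch.randn(32, 2, 8, 8)
x_i = 0.8 * x_r + 0.2 * torch.randn(32, 2, 8, 8)   # strongly correlated parts
y_r, y_i = NaiveComplexBatchNorm2dSketch(2)(x_r, x_i)

# Both outputs have unit variance per channel, so this is their correlation;
# it stays close to 1 because the naive version does not decorrelate.
corr = (y_r * y_i).mean(dim=[0, 2, 3])
print(corr)
```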

I cannot reproduce the memory leak using this simple code:

import torch
from complexLayers import ComplexBatchNorm2D

device = 'cuda'
bn = ComplexBatchNorm2D(3)
bn.to(device)
x_r = torch.randn(10, 3, 24, 24, device=device)
x_i = torch.randn(10, 3, 24, 24, device=device)

print(torch.cuda.max_memory_allocated() / 1024**3)

for epoch in range(100):
    bn.zero_grad()
    output = bn(x_r, x_i)
    print('Epoch {}, max mem {}'.format(
        epoch, torch.cuda.max_memory_allocated() / 1024**3))
    output[0].mean().backward()
    output[1].mean().backward()

Could you take a look and give me a hint how your usage differs from mine?

Hi ptrblck,

Thanks for your reply.

I tested your code, and there is no memory issue when I run it. However, I still have those issues when I use ComplexBatchNorm2D in an actual neural network.

I tried to write as minimal an example as I could to replicate my issue. When I run the following, the allocated memory does increase. If I comment out the lines corresponding to the batch norm, it does not.

Any idea?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from complexLayers import ComplexBatchNorm2D

batch_size = 64
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = datasets.MNIST('../data', train=True, transform=trans, download=True)
test_set = datasets.MNIST('../data', train=False, transform=trans, download=True)

train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True)

class ComplexNet(nn.Module):
    
    def __init__(self):
        super(ComplexNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.bn  = ComplexBatchNorm2D(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
             
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x,_ = self.bn(x,x)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
device = torch.device("cuda:3" )
model = ComplexNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

for epoch in range(100):
    train(model, device, train_loader, optimizer, epoch)
    print('Epoch {} '.format(epoch))
    print('Max allocated memory: {:.0f}MiB'.format(torch.cuda.max_memory_allocated(device = device) / 1024**2))

Thank you very much for this great code to reproduce this issue!

Indeed the memory is growing in each epoch.
After looking into the code, I think the reason is that you might track the computation graph in self.running_mean and self.running_covar unintentionally.
This might be the case if you assign a value with a grad_fn to these tensors:

self.running_mean = exponential_average_factor * mean\
    + (1 - exponential_average_factor) * self.running_mean  # mean is holding onto the computation graph
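A small reproduction of this mechanism, with a hypothetical Tracker module. Assigning to a registered buffer name keeps it a buffer, but the stored tensor now carries a grad_fn. Each later update builds on top of that graph, so the history of every batch stays alive.

```python
import torch
import torch.nn as nn

class Tracker(nn.Module):
    """Hypothetical minimal module with one running-average buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(3))

    def update(self, mean, factor=0.1, use_no_grad=False):
        if use_no_grad:
            with torch.no_grad():
                self.running_mean = factor * mean + (1 - factor) * self.running_mean
        else:
            self.running_mean = factor * mean + (1 - factor) * self.running_mean

x = torch.randn(8, 3, requires_grad=True)

leaky = Tracker()
for _ in range(3):
    leaky.update(x.mean(0))
# Still registered as a buffer, but it now references the whole update history.
print("running_mean" in dict(leaky.named_buffers()))  # True
print(leaky.running_mean.grad_fn is not None)         # True

fixed = Tracker()
for _ in range(3):
    fixed.update(x.mean(0), use_no_grad=True)
print(fixed.running_mean.grad_fn is None)             # True
```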

If you wrap the update codes into a with torch.no_grad() guard, the memory footprint stays constant:

class ComplexBatchNorm2D(_ComplexBatchNorm):     
    
    def forward(self, input_r, input_i):
        assert(input_r.size() == input_i.size())
        assert(len(input_r.shape) == 4)
        #self._check_input_dim(input)

        exponential_average_factor = 0.0
        

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        
        if self.training:
             
            # calculate mean of real and imaginary part
            mean_r = input_r.mean([0, 2, 3])
            mean_i = input_i.mean([0, 2, 3])
            
            mean = torch.stack((mean_r,mean_i),dim=1)
            
            with torch.no_grad():
                # update running mean
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
            
            # works for 2d 
            input_r = input_r-mean_r[None, :, None, None]
            input_i = input_i-mean_i[None, :, None, None]          
        
            # Elements of the covariance matrix (biased for train)
            n = input_r.numel() / input_r.size(1)
            Crr = 1./n*input_r.pow(2).sum(dim=[0,2,3])+self.eps
            Cii = 1./n*input_i.pow(2).sum(dim=[0,2,3])+self.eps
            Cri = (input_r.mul(input_i)).mean(dim=[0,2,3])

            with torch.no_grad():
    
                self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_covar[:,0]
                
                self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_covar[:,1]
                    
                self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_covar[:,2]
                
        else:
            mean = self.running_mean
            Crr = self.running_covar[:,0]+self.eps
            Cii = self.running_covar[:,1]+self.eps
            Cri = self.running_covar[:,2]#+self.eps

            input_r = input_r-mean[None,:,0,None,None]
            input_i = input_i-mean[None,:,1,None,None]

        # calculate the inverse square root of the covariance matrix
        det = Crr*Cii-Cri.pow(2)
        s = torch.sqrt(det)
        t = torch.sqrt(Cii+Crr + 2 * s)
        inverse_st = 1.0 / (s * t)
        Rrr = (Cii + s) * inverse_st
        Rii = (Crr + s) * inverse_st
        Rri = -Cri * inverse_st
                  
        input_r, input_i = Rrr[None,:,None,None]*input_r+Rri[None,:,None,None]*input_i, \
                           Rii[None,:,None,None]*input_i+Rri[None,:,None,None]*input_r
                           
        if self.affine:
            input_r, input_i = self.weight[None,:,0,None,None]*input_r+self.weight[None,:,2,None,None]*input_i+\
                               self.bias[None,:,0,None,None], \
                               self.weight[None,:,2,None,None]*input_r+self.weight[None,:,1,None,None]*input_i+\
                               self.bias[None,:,1,None,None]

        return input_r, input_i

Great! That solves my memory issue, no more leaks. Thanks a lot!

I will be able to perform more tests and try to optimize the code.

Just a question: is this expected behaviour in PyTorch? It would seem logical to me that using register_buffer should prevent such tracking. Is there a better way to initialize running_mean?

Thanks again!

SMP

Although you register the tensor as a buffer initially, the assignment replaces self.running_xxx with a new tensor computed from the batch statistics, which stores the computation graph.

I’m not sure some kind of “recreation guard” would make sense, since there might be special cases where this behavior is needed, so I tend to disable the gradient manually instead.

Would using detach also give the same result?

self.running_mean = exponential_average_factor * mean.detach()\
                    + (1 - exponential_average_factor) * self.running_mean

Yep, it seems to do the job too, thanks!

I do not know if one is better, but I like with torch.no_grad() as it makes it obvious to the reader what the tricky part is.
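A third variant, closer to how the stock layer's buffers are updated by the backend, is to modify the buffer in place (here with Tensor.lerp_) inside torch.no_grad(). Both no_grad versions avoid storing a graph, but only the in-place one keeps the very same tensor object registered. A sketch, with an arbitrarily chosen factor:

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("running_mean", torch.zeros(3))
x = torch.randn(8, 3, requires_grad=True)
factor = 0.1  # plays the role of exponential_average_factor

before = m.running_mean
with torch.no_grad():
    # Reassignment (as in the thread): a new tensor replaces the buffer.
    m.running_mean = factor * x.mean(0) + (1 - factor) * m.running_mean
print(before is m.running_mean)  # False

before = m.running_mean
with torch.no_grad():
    # In place: running_mean + factor * (mean - running_mean), same tensor object.
    m.running_mean.lerp_(x.mean(0), factor)
print(before is m.running_mean, m.running_mean.grad_fn is None)  # True True
```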