Workaround for BN with small batch size

I have been trying to find a way to use a large batch size on limited GPU memory. Gradient accumulation is a good solution, and with it I am able to use a 2x or 4x larger batch size. See How to implement accumulated gradient? - #16 by sbelharbi
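
For context, the accumulation itself looks roughly like this (a minimal sketch; accumulation_steps, criterion and the dataloader variables are placeholders, not my exact training code):

accumulation_steps = 4
optimizer.zero_grad()
for step, (imgs, targets) in enumerate(loader):
    out = model(imgs)
    # scale the loss so the accumulated gradients match a single large batch
    loss = criterion(out, targets) / accumulation_steps
    loss.backward()                      # gradients accumulate in the .grad buffers
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()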

However, this does not help with BN, since the mean and variance are still computed over the small batch that actually fits on the GPU, as also pointed out by @ptrblck.
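
Concretely, with the default momentum=0.1 each forward pass updates the running statistics from whatever chunk is currently on the GPU, so with accumulation the estimate comes from e.g. 8 samples instead of 16. A toy illustration (shapes are made up):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                    # default momentum = 0.1
bn.train()
chunk = torch.randn(8, 3, 32, 32)         # the micro-batch that actually fits on the GPU
_ = bn(chunk)
# running_mean = (1 - 0.1) * 0 + 0.1 * chunk.mean([0, 2, 3])  -> based on 8 samples only
print(torch.allclose(bn.running_mean, 0.1 * chunk.mean([0, 2, 3])))   # True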

As a workaround, I implemented a training setup that has two phases:

  • Phase 1: a forward pass with the full batch (say 16) that updates the BN running mean/var and also saves the current mean/var computed from the batch of 16 images. There is no backward pass, and the gradients of the whole model are cleared afterwards (model.zero_grad()).
  • Phase 2: the batch is split into smaller chunks (e.g. 2x batch size 8); the backward passes accumulate gradients over the chunks, followed by a single optimization step. In this phase the batch normalization layers do not update their running mean/var and instead normalize with the mean/var of the full batch that was saved in phase 1.

However, this implementation results in a much higher training loss as well as a lower validation metric (IoU).

The BN implementation:

import torch
import torch.nn as nn

class CustomBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(CustomBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.set_phase(phase=1)

        self.register_buffer('batch_mean', torch.zeros(num_features))
        self.register_buffer('batch_var', torch.ones(num_features))

    def set_phase(self, phase=1):
        self.phase = phase

    def end_cycle(self): # reset the current batch mean and variance
        self.batch_mean.zero_()
        self.batch_var.fill_(1)
        self.set_phase(phase=1)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.phase == 1:  # only update the running mean and var in phase one
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                    if self.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            if self.phase == 1:  # phase one: update the running mean/var and save the current batch mean/var
                with torch.no_grad():
                    self.running_mean = exponential_average_factor * mean\
                        + (1 - exponential_average_factor) * self.running_mean
                    # update running_var with unbiased var
                    self.running_var = exponential_average_factor * var * n / (n - 1)\
                        + (1 - exponential_average_factor) * self.running_var
                    self.batch_mean.copy_(mean)
                    self.batch_var.copy_(var)
            elif self.phase == 2:  # phase two: use the mean/var saved from the full batch in phase one
                mean = self.batch_mean
                var = self.batch_var
        else: # eval, use the running mean and variance
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input
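
For reference, the stock nn.BatchNorm2d layers in the model are replaced by this custom layer. The recursive replace_bn helper below is only a rough sketch of how one could do that (it is not my exact code):

def replace_bn(module):
    # recursively swap nn.BatchNorm2d children for CustomBatchNorm2d,
    # copying the affine parameters and running statistics over
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d) and not isinstance(child, CustomBatchNorm2d):
            new_bn = CustomBatchNorm2d(child.num_features, child.eps, child.momentum,
                                       child.affine, child.track_running_stats)
            # strict=False because CustomBatchNorm2d has the extra batch_mean/batch_var buffers
            new_bn.load_state_dict(child.state_dict(), strict=False)
            setattr(module, name, new_bn)
        else:
            replace_bn(child)

replace_bn(model)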

Also, the training loop for one batch (a single iteration):

    model.train()
    model.zero_grad()  

  
    for layer in model.modules():
        if isinstance(layer, CustomBatchNorm2d):  # only the custom BN layers define set_phase
            layer.set_phase(phase=1)

    with torch.no_grad():
        _ = model(imgs)
    model.zero_grad()
    model.train()

    for layer in model.modules():
        if isinstance(layer, CustomBatchNorm2d):
            layer.set_phase(phase=2)

    splits = 4
    split_batch = imgs.size(0) // splits

    for split in range(splits):
        imgs_split = imgs[split * split_batch:(split + 1) * split_batch, :, :, :]
        loss = ....... # get loss for the split 
        loss /= splits
        loss.backward() # accumulate gradients 

    optimizer.step()

    for layer in model.modules():
        if isinstance(layer, CustomBatchNorm2d):
            layer.end_cycle()
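
As a sanity check (not part of my current code), in phase 1 the custom layer should behave exactly like a stock nn.BatchNorm2d, which could be verified with something like:

ref_bn = nn.BatchNorm2d(16)
custom_bn = CustomBatchNorm2d(16)
custom_bn.load_state_dict(ref_bn.state_dict(), strict=False)  # strict=False for the extra buffers

x = torch.randn(16, 16, 8, 8)
ref_bn.train()
custom_bn.train()
custom_bn.set_phase(phase=1)
# both the normalized output and the updated running stats should match (expect True twice)
print(torch.allclose(ref_bn(x), custom_bn(x), atol=1e-5))
print(torch.allclose(ref_bn.running_mean, custom_bn.running_mean, atol=1e-5))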

It would be helpful if someone could suggest what I might be doing incorrectly here.