I have been trying to find a way to use a large batch size with limited GPU memory. Gradient accumulation is a good solution, and with it I can use a 2x or 4x larger effective batch size. See How to implement accumulated gradient? - #16 by sbelharbi
However, this does not hold for BN, since the mean and variance are still computed over the small batch that fits on the GPU, as also pointed out by @ptrblck.
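For reference, the plain gradient-accumulation loop I started from looks roughly like this (a minimal sketch; model, optimizer, criterion, imgs and targets are placeholders):

# one optimizer step per `splits` micro-batches
splits = 4
split_batch = imgs.size(0) // splits

optimizer.zero_grad()
for split in range(splits):
    imgs_split = imgs[split * split_batch:(split + 1) * split_batch]
    targets_split = targets[split * split_batch:(split + 1) * split_batch]
    loss = criterion(model(imgs_split), targets_split)
    (loss / splits).backward()  # accumulate scaled gradients
optimizer.step()

The gradients are summed across the splits, but each forward pass still normalizes with the statistics of its own small micro-batch.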
As a workaround, I thought of implementing a training setup with two phases:
- Phase 1: a forward pass with the large batch (say 16) that updates the BN running mean/var and also saves the current mean/var computed from the batch of 16 images, followed by clearing the gradients for the whole model (model.zero_grad()); no backward pass.
- Phase 2: splits the batch into two halves (2x a batch size of 8), runs two backward passes to accumulate gradients, and takes a single optimization step. In this phase the batch normalization does not update the running mean/var and instead uses the mean/var of the batch of 16 that was saved in phase 1.
However, this implementation results in a much higher training loss as well as a lower validation metric (IoU).
The BN implementation:
import torch
import torch.nn as nn


class CustomBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(CustomBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)
        self.set_phase(phase=1)
        # buffers holding the batch statistics saved during phase 1
        self.register_buffer('batch_mean', torch.zeros(num_features))
        self.register_buffer('batch_var', torch.ones(num_features))

    def set_phase(self, phase=1):
        self.phase = phase

    def end_cycle(self):
        # reset the saved batch mean/variance and go back to phase 1
        self.batch_mean.zero_()
        self.batch_var.fill_(1)
        self.set_phase(phase=1)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.phase == 1:  # only update the running mean/var in phase 1
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                    if self.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            if self.phase == 1:  # phase 1: update the running mean/var and save the current batch mean/var
                with torch.no_grad():
                    self.running_mean = exponential_average_factor * mean \
                        + (1 - exponential_average_factor) * self.running_mean
                    # update running_var with unbiased var
                    self.running_var = exponential_average_factor * var * n / (n - 1) \
                        + (1 - exponential_average_factor) * self.running_var
                    self.batch_mean.copy_(mean)
                    self.batch_var.copy_(var)
            elif self.phase == 2:  # phase 2: use the mean/var saved in phase 1 for every split
                mean = self.batch_mean
                var = self.batch_var
        else:  # eval: use the running mean and variance
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        return input
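For context, this is roughly how I swap the BN layers in an existing model (a sketch; replace_bn is just a helper name I use here, assuming the model was built with plain nn.BatchNorm2d and is still on the CPU):

def replace_bn(module):
    # recursively swap every plain nn.BatchNorm2d for CustomBatchNorm2d,
    # copying over the affine parameters and running statistics
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d) and not isinstance(child, CustomBatchNorm2d):
            new_bn = CustomBatchNorm2d(child.num_features, child.eps, child.momentum,
                                       child.affine, child.track_running_stats)
            # strict=False because batch_mean/batch_var do not exist in the old layer
            new_bn.load_state_dict(child.state_dict(), strict=False)
            setattr(module, name, new_bn)
        else:
            replace_bn(child)

replace_bn(model)  # do this before creating the optimizer / moving the model to the GPU

Since CustomBatchNorm2d subclasses nn.BatchNorm2d, the isinstance(layer, nn.BatchNorm2d) checks in the training loop below still match the replaced layers.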
And here is the training loop for one batch (a single iteration):
model.train()
model.zero_grad()

# phase 1: full-batch forward pass (no grad) to compute and store the BN statistics
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.set_phase(phase=1)
with torch.no_grad():
    _ = model(imgs)
model.zero_grad()

# phase 2: split the batch and accumulate gradients using the saved BN statistics
model.train()
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.set_phase(phase=2)

splits = 4
split_batch = imgs.size(0) // splits
for split in range(splits):
    imgs_split = imgs[split * split_batch:(split + 1) * split_batch]
    loss = .......  # get the loss for the split
    loss /= splits
    loss.backward()  # accumulate gradients

optimizer.step()

# reset the saved BN statistics and go back to phase 1
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.end_cycle()
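As a quick sanity check (just a sketch), I verify after end_cycle() that the saved batch statistics were reset and all BN layers are back in phase 1:

for layer in model.modules():
    if isinstance(layer, CustomBatchNorm2d):
        assert torch.all(layer.batch_mean == 0), "batch_mean was not reset"
        assert torch.all(layer.batch_var == 1), "batch_var was not reset"
        assert layer.phase == 1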
It would be helpful if someone could suggest what I may be doing incorrectly here.