I have been trying to find a way to use a large batch size with limited GPU memory. Gradient accumulation is a good solution, and with it I can use an effective batch size 2x or 4x larger. See "How to implement accumulated gradient?" - #16 by sbelharbi.
However, this does not hold for batch normalization (BN): the mean and variance are still computed over the small sub-batch that fits on the GPU, as @ptrblck also pointed out.
As a workaround, I implemented a training setup with two phases:
- Phase 1: a forward pass over the full batch (say 16 images) that updates the BN running mean/var and saves the current mean/var computed over all 16 images. This is followed by clearing the gradients of the whole model (`model.zero_grad()`); there is no backward pass in this phase.
- Phase 2: the batch is split into smaller sub-batches (e.g. 2 x 8; the code below uses 4 splits). Gradients are accumulated over them (one backward pass per sub-batch), followed by a single optimization step. In this phase, batch normalization does not update the running mean/var; instead it uses the mean/var of the full 16-image batch that was saved in phase 1.
However, this implementation results in a much higher training loss and a lower validation metric (IoU).
The BN implementation
import torch
import torch.nn as nn


class CustomBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that can reuse statistics saved from a larger batch.

    Intended for a two-phase training step:

    * Phase 1: forward pass over the full batch. Running statistics are
      updated (when ``track_running_stats`` is true) and the per-channel batch
      mean and biased variance are saved into ``batch_mean`` / ``batch_var``.
    * Phase 2: forward passes over sub-batches (e.g. for gradient
      accumulation). Each sub-batch is normalized with the statistics saved in
      phase 1; running statistics are left untouched.

    The saved statistics are copied under ``torch.no_grad()``, so in phase 2
    they act as constants: unlike ordinary training-mode batch norm, no
    gradient flows through the mean/variance.

    Call ``end_cycle()`` after each optimization step to reset the saved
    statistics and return to phase 1.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.set_phase(phase=1)
        # Per-channel mean / biased variance of the most recent phase-1 batch.
        self.register_buffer('batch_mean', torch.zeros(num_features))
        self.register_buffer('batch_var', torch.ones(num_features))

    def set_phase(self, phase=1):
        """Select phase 1 (compute and save stats) or phase 2 (reuse saved stats).

        Raises:
            ValueError: if ``phase`` is neither 1 nor 2.
        """
        if phase not in (1, 2):
            raise ValueError("phase must be 1 or 2, got {!r}".format(phase))
        self.phase = phase

    def end_cycle(self):
        """Reset the saved batch statistics and switch back to phase 1."""
        self.batch_mean.zero_()
        self.batch_var.fill_(1)
        self.set_phase(phase=1)

    def forward(self, input):
        self._check_input_dim(input)

        if not self.training and self.running_mean is not None:
            # Eval mode with tracked running statistics.
            mean, var = self.running_mean, self.running_var
        elif self.training and self.phase == 2:
            # Reuse the statistics saved from the full phase-1 batch.
            mean, var = self.batch_mean, self.batch_var
        else:
            # Phase-1 training, or eval without running statistics (as
            # nn.BatchNorm2d does when track_running_stats=False).
            mean = input.mean([0, 2, 3])
            var = input.var([0, 2, 3], unbiased=False)  # biased var for normalization
            if self.training:
                self._record_phase_one_stats(input, mean, var)
        return self._normalize(input, mean, var)

    def _record_phase_one_stats(self, input, mean, var):
        """Update running statistics (if tracked) and save this batch's mean/var."""
        n = input.numel() / input.size(1)  # values per channel
        if n <= 1:
            # n / (n - 1) would divide by zero and corrupt running_var.
            raise ValueError(
                "Expected more than 1 value per channel when training, "
                "got input size {}".format(tuple(input.size())))
        with torch.no_grad():
            if self.track_running_stats:
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative moving average.
                    factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # Exponential moving average.
                    factor = self.momentum
                # In-place updates keep the registered buffer tensors themselves.
                self.running_mean.mul_(1 - factor).add_(mean, alpha=factor)
                # The running variance stores the unbiased estimate.
                self.running_var.mul_(1 - factor).add_(var * (n / (n - 1)), alpha=factor)
            self.batch_mean.copy_(mean)
            self.batch_var.copy_(var)

    def _normalize(self, input, mean, var):
        """Normalize per channel with ``mean``/``var``, then apply the affine map."""
        out = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            out = out * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        return out
And here is the training loop for one batch (a single iteration):
model.train()

# Phase 1: one forward pass over the whole batch without building a graph.
# Each CustomBatchNorm2d layer updates its running statistics and saves the
# full-batch mean/variance for phase 2.
for layer in model.modules():
    # Only CustomBatchNorm2d defines set_phase(); checking for nn.BatchNorm2d
    # would also match plain BatchNorm layers, which would raise AttributeError.
    if isinstance(layer, CustomBatchNorm2d):
        layer.set_phase(phase=1)
with torch.no_grad():
    model(imgs)

# Phase 2: accumulate gradients over sub-batches. Every sub-batch is
# normalized with the statistics saved in phase 1, and running statistics
# are not updated.
model.zero_grad()
for layer in model.modules():
    if isinstance(layer, CustomBatchNorm2d):
        layer.set_phase(phase=2)
splits = 4
batch_size = imgs.size(0)  # size along the batch dimension (an int, not torch.Size)
# Ceiling division so remainder samples are not silently dropped.
split_size = -(-batch_size // splits)
for imgs_split in imgs.split(split_size, dim=0):
    loss = ...  # TODO: compute the loss for imgs_split (elided in the post)
    # Weight by the sub-batch's share of the batch. Assuming the per-split loss
    # is a mean over its samples, the accumulated gradient then matches the
    # full-batch mean loss, even when the last sub-batch is smaller.
    loss = loss * (imgs_split.size(0) / batch_size)
    loss.backward()  # accumulate gradients
optimizer.step()

# Reset every CustomBatchNorm2d layer to phase 1 for the next iteration.
for layer in model.modules():
    if isinstance(layer, CustomBatchNorm2d):
        layer.end_cycle()
It would be helpful if someone could point out what I may be doing wrong here.