Custom batchnorm2d

hi,
I’m trying to understand the calculation of BatchNorm2d and have made custom BN as follows, but it failed.
could anyone help me find out the mistake? THANKS!!!

``````class MyBNFunc(Function):
@staticmethod
def forward(ctx, input, avg, var, gamma, beta, eps):
B, C, H, W = input.shape
ctx.avg = avg
ctx.var = var
ctx.eps = eps
ctx.B = B
output = input - avg
output = output / torch.sqrt(var + eps)
scaled_output = output * gamma + beta
ctx.save_for_backward(input, gamma, beta, output)
return scaled_output

@staticmethod
input, gamma, beta, output = ctx.saved_tensors
avg = ctx.avg
var = ctx.var
eps = ctx.eps
B = ctx.B
dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B
dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
return dL_dxi, dL_davg, dL_dvar, dL_dgamma, dL_dbeta, None

class MyBN(nn.Module):
def __init__(self, num_features, momentum=0.1, eps=1e-5):
super(MyBN, self).__init__()
self.num_features = num_features
self.momentum = momentum
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))

self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.zeros(num_features))
self.reset_parameters()
self.eps = eps

def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
self.weight.data.uniform_()
self.bias.data.zero_()

def forward(self, inp):

gamma = self.weight.view(1, self.num_features, 1, 1)
beta = self.bias.view(1, self.num_features, 1, 1)

if self.training:
B, C, H, W = inp.shape
y = inp.transpose(0, 1).contiguous()  # C x B x H x W
avg = y.view(C, -1).mean(-1)  # C
self.running_mean = (1 - self.momentum) * avg + self.momentum * self.running_mean
var = y.view(C, -1).var(-1)
self.running_var = (1 - self.momentum) * var + self.momentum * self.running_var
else:
avg = self.running_mean
var = self.running_var
avg = avg.view(1, -1, 1, 1)
var = var.view(1, -1, 1, 1)

out = MyBNFunc.apply(inp, avg, var, gamma, beta, self.eps)

return out``````

Hi,

Could you be more precise about the error you’re seeing? How does it fail?

I replace nn.BatchNorm2d with MyBN in Resnet18 and use Cifar100 for training.
and it turned out that the loss would be extremely high (e.g. 500) or even nan, and the acc is approximately 1%.
while in the nn.BatchNorm2d case, I could witness the decline of loss and the rise of accuracy.

Have you checked that your version gives the same result that nn.BatchNorm2d when you forward the same random input in both of them?

I have checked the result and found this phenomenon:
if I use “version 1” instead of MyBNFunc, everything works well and it trained successfully.
but if I turned to use the current version, the training failed but the forward output matched exactly.
however, I implemented my function according to the BatchNorm paper, and cannot find what is wrong with it.

``````class MyBNFunc(Function):
@staticmethod
def forward(ctx, input, avg, var, gamma, beta, eps):
B, C, H, W = input.shape
avg = avg[None, :, None, None]
var = var[None, :, None, None]
ctx.avg = avg
ctx.var = var
ctx.eps = eps
ctx.B = B
output = input - avg
output = output / (torch.sqrt(var + eps))
ctx.save_for_backward(input, gamma, beta, output)
scaled_output = output * gamma + beta
return scaled_output

@staticmethod
input, gamma, beta, output = ctx.saved_tensors
avg = ctx.avg
var = ctx.var
eps = ctx.eps
B = ctx.B
dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B # dL_dxi_hat / sqrt()
dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) #(grad_output * output).sum((0, 2, 3), keepdim=True)
dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
return dL_dxi, None, None, dL_dgamma, dL_dbeta, None

class MyBN1(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(MyBN1, self).__init__(
num_features, eps, momentum, affine, track_running_stats)

def forward(self, input):
self._check_input_dim(input)

exponential_average_factor = 0.0

if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None:  # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:  # use exponential moving average
exponential_average_factor = self.momentum

# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3])
# use biased var in train
var = input.var([0, 2, 3], unbiased=False)
n = input.numel() / input.size(1)
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var
weight = self.weight[None, :, None, None]
bias = self.bias[None, :, None, None]

output = MyBNFunc.apply(input, mean, var, weight, bias, self.eps)
'''
VERSION 1
input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
if self.affine:
input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
'''

return output

``````

If the forward properly match, you can use the gradcheck method from torch.autograd to check that your backward implementation is correct.

After trial and error, I still could not find where the bug is

``````class MyBN1(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(MyBN1, self).__init__(
num_features, eps, momentum, affine, track_running_stats)

def forward(self, input):
self._check_input_dim(input)

exponential_average_factor = 0.0

if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None:  # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:  # use exponential moving average
exponential_average_factor = self.momentum

# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3])
var = input.var([0, 2, 3], unbiased=False)
n = input.numel() / input.size(1)
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
mean = mean.view(1, -1, 1, 1)
var = var.view(1, -1, 1, 1)

else:
mean = self.running_mean
var = self.running_var
weight = self.weight.view(1, -1, 1, 1)
bias = self.bias.view(1, -1, 1, 1)

output = MyBNFunc.apply(input, mean, var, weight, bias, self.eps)

return output

def mysum(tensor):
return tensor.sum((0,2,3),keepdim=True)
class MyBNFunc(Function):
@staticmethod
def forward(ctx, input, avg, var, gamma, beta, eps):
ctx.avg = avg
ctx.var = var
ctx.eps = eps
ctx.shape = input.shape
output = input - avg
scale = 1 / torch.sqrt(var + eps)
output = output * scale
ctx.save_for_backward(input, gamma, beta, output, scale)
output = output * gamma + beta
return output

@staticmethod
input, gamma, beta, output, scale = ctx.saved_tensors
avg = ctx.avg
var = ctx.var
eps = ctx.eps
B, C, H, W = ctx.shape

dl_dvar = mysum(dL_dxi_hat) * (input - avg) * -0.5 * scale * scale * scale
dl_davg = mysum(dL_dxi_hat * -1.0 * scale) + dl_dvar * mysum(-2.0 * (input - avg)) / (B * H * W)
dL_dxi = dL_dxi_hat * scale + dl_dvar * 2.0 * (input - avg) / (B*H*W)  + dl_davg / (B*H*W)
dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
return dL_dxi, dl_davg, dl_dvar, dL_dgamma, dL_dbeta, None

``````

compared to the result of torch.nn.BatchNorm2d, the gradient of gamma and beta match exactly, but the input_grad is wrong, could anyone help me?

Could you share your testing script as well that runs both and shows the difference please?

@Mason-Qin have you found your bug? I have a very similar implementation and suffer from the same problem you described.