hi,
I’m trying to understand the calculation of BatchNorm2d and have made custom BN as follows, but it failed.
could anyone help me find out the mistake? THANKS!!!
class MyBNFunc(Function):
@staticmethod
def forward(ctx, input, avg, var, gamma, beta, eps):
B, C, H, W = input.shape
ctx.avg = avg
ctx.var = var
ctx.eps = eps
ctx.B = B
output = input - avg
output = output / torch.sqrt(var + eps)
scaled_output = output * gamma + beta
ctx.save_for_backward(input, gamma, beta, output)
return scaled_output
@staticmethod
def backward(ctx, grad_output):
input, gamma, beta, output = ctx.saved_tensors
avg = ctx.avg
var = ctx.var
eps = ctx.eps
B = ctx.B
dL_dxi_hat = grad_output * gamma
dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B
dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
return dL_dxi, dL_davg, dL_dvar, dL_dgamma, dL_dbeta, None
class MyBN(nn.Module):
def __init__(self, num_features, momentum=0.1, eps=1e-5):
super(MyBN, self).__init__()
self.num_features = num_features
self.momentum = momentum
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.zeros(num_features))
self.reset_parameters()
self.eps = eps
def reset_parameters(self):
self.running_mean.zero_()
self.running_var.fill_(1)
self.weight.data.uniform_()
self.bias.data.zero_()
def forward(self, inp):
gamma = self.weight.view(1, self.num_features, 1, 1)
beta = self.bias.view(1, self.num_features, 1, 1)
if self.training:
B, C, H, W = inp.shape
y = inp.transpose(0, 1).contiguous() # C x B x H x W
avg = y.view(C, -1).mean(-1) # C
self.running_mean = (1 - self.momentum) * avg + self.momentum * self.running_mean
var = y.view(C, -1).var(-1)
self.running_var = (1 - self.momentum) * var + self.momentum * self.running_var
else:
avg = self.running_mean
var = self.running_var
avg = avg.view(1, -1, 1, 1)
var = var.view(1, -1, 1, 1)
out = MyBNFunc.apply(inp, avg, var, gamma, beta, self.eps)
return out