Custom batchnorm2d

hi,
I’m trying to understand the calculation of BatchNorm2d and have made custom BN as follows, but it failed.
could anyone help me find out the mistake? THANKS!!!

class MyBNFunc(Function):
    @staticmethod
    def forward(ctx, input, avg, var, gamma, beta, eps):
        B, C, H, W = input.shape
        ctx.avg = avg
        ctx.var = var
        ctx.eps = eps
        ctx.B = B
        output = input - avg
        output = output / torch.sqrt(var + eps)
        scaled_output = output * gamma + beta
        ctx.save_for_backward(input, gamma, beta, output)
        return scaled_output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, gamma, beta, output = ctx.saved_tensors
        avg = ctx.avg
        var = ctx.var
        eps = ctx.eps
        B = ctx.B
        dL_dxi_hat = grad_output * gamma
        dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
        dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
        dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B 
        dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) 
        dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
        return dL_dxi, dL_davg, dL_dvar, dL_dgamma, dL_dbeta, None


class MyBN(nn.Module):
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super(MyBN, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))
        self.reset_parameters()
        self.eps = eps

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, inp):

        gamma = self.weight.view(1, self.num_features, 1, 1)
        beta = self.bias.view(1, self.num_features, 1, 1)

        if self.training:
            B, C, H, W = inp.shape
            y = inp.transpose(0, 1).contiguous()  # C x B x H x W
            avg = y.view(C, -1).mean(-1)  # C
            self.running_mean = (1 - self.momentum) * avg + self.momentum * self.running_mean
            var = y.view(C, -1).var(-1)
            self.running_var = (1 - self.momentum) * var + self.momentum * self.running_var
        else:
            avg = self.running_mean
            var = self.running_var
        avg = avg.view(1, -1, 1, 1)
        var = var.view(1, -1, 1, 1)

        out = MyBNFunc.apply(inp, avg, var, gamma, beta, self.eps)

        return out

Hi,

Could you be more precise about the error you’re seeing? How does it fail?

I replace nn.BatchNorm2d with MyBN in Resnet18 and use Cifar100 for training.
and it turned out that the loss would be extremely high (e.g. 500) or even nan, and the acc is approximately 1%.
while in the nn.BatchNorm2d case, I could witness the decline of loss and the rise of accuracy.

Have you checked that your version gives the same result that nn.BatchNorm2d when you forward the same random input in both of them?

I have checked the result and found this phenomenon:
if I use “version 1” instead of MyBNFunc, everything works well and it trained successfully.
but if I turned to use the current version, the training failed but the forward output matched exactly.
however, I implemented my function according to the BatchNorm paper, and cannot find what is wrong with it.

class MyBNFunc(Function):
    @staticmethod
    def forward(ctx, input, avg, var, gamma, beta, eps):
        B, C, H, W = input.shape
        avg = avg[None, :, None, None]
        var = var[None, :, None, None]
        ctx.avg = avg
        ctx.var = var
        ctx.eps = eps
        ctx.B = B
        output = input - avg
        output = output / (torch.sqrt(var + eps))
        ctx.save_for_backward(input, gamma, beta, output)
        scaled_output = output * gamma + beta
        return scaled_output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, gamma, beta, output = ctx.saved_tensors
        avg = ctx.avg
        var = ctx.var
        eps = ctx.eps
        B = ctx.B
        dL_dxi_hat = grad_output * gamma
        dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
        dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
        dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B # dL_dxi_hat / sqrt()
        dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) #(grad_output * output).sum((0, 2, 3), keepdim=True)
        dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
        return dL_dxi, None, None, dL_dgamma, dL_dbeta, None
    
    
class MyBN1(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBN1, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var
        weight = self.weight[None, :, None, None]
        bias = self.bias[None, :, None, None]

        output = MyBNFunc.apply(input, mean, var, weight, bias, self.eps)
        '''
        VERSION 1
        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        '''

        return output

If the forward properly match, you can use the gradcheck method from torch.autograd to check that your backward implementation is correct.

After trial and error, I still could not find where the bug is

class MyBN1(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBN1, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
            mean = mean.view(1, -1, 1, 1)
            var = var.view(1, -1, 1, 1)
                
        else:
            mean = self.running_mean
            var = self.running_var
        weight = self.weight.view(1, -1, 1, 1)
        bias = self.bias.view(1, -1, 1, 1)

        output = MyBNFunc.apply(input, mean, var, weight, bias, self.eps)
        
        return output

def mysum(tensor):
    return tensor.sum((0,2,3),keepdim=True)
class MyBNFunc(Function):
    @staticmethod
    def forward(ctx, input, avg, var, gamma, beta, eps):
        ctx.avg = avg
        ctx.var = var
        ctx.eps = eps
        ctx.shape = input.shape
        output = input - avg
        scale = 1 / torch.sqrt(var + eps)
        output = output * scale
        ctx.save_for_backward(input, gamma, beta, output, scale)
        output = output * gamma + beta
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, gamma, beta, output, scale = ctx.saved_tensors
        avg = ctx.avg
        var = ctx.var
        eps = ctx.eps
        B, C, H, W = ctx.shape

        dL_dxi_hat = grad_output * gamma
        dl_dvar = mysum(dL_dxi_hat) * (input - avg) * -0.5 * scale * scale * scale
        dl_davg = mysum(dL_dxi_hat * -1.0 * scale) + dl_dvar * mysum(-2.0 * (input - avg)) / (B * H * W)
        dL_dxi = dL_dxi_hat * scale + dl_dvar * 2.0 * (input - avg) / (B*H*W)  + dl_davg / (B*H*W)
        dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) 
        dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
        return dL_dxi, dl_davg, dl_dvar, dL_dgamma, dL_dbeta, None

compared to the result of torch.nn.BatchNorm2d, the gradient of gamma and beta match exactly, but the input_grad is wrong, could anyone help me?

Could you share your testing script as well that runs both and shows the difference please?

@Mason-Qin have you found your bug? I have a very similar implementation and suffer from the same problem you described.