Mean-only batch-norm

Does anyone know a way of implementing the mean-only batch norm from https://arxiv.org/abs/1602.07868? I.e. I want to subtract the mean from the activations but not divide them by the standard deviation.

Maybe you can just subclass this:

https://pytorch.org/docs/master/_modules/torch/nn/modules/batchnorm.html

and override the forward function, calling self.running_var.fill_(1) before calling F.batch_norm(...)

Thanks! But I want this mean-only behavior for training as well, not just for inference. So fixing the running variance would not help?

It should also work while training. gamma and beta are still trainable parameters; the only difference is that they are applied to inputs normalized using the batch mean statistics and a fixed variance.

The issue is not beta and gamma. At training time, the statistics of the current batch are used, not the running averages.
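
A quick way to see this is to call F.batch_norm directly in both modes with running_var set to 1. This is only a small sketch: the tensor shapes, the scale factor of 5 and the variable names are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative activations with a standard deviation of roughly 5.
x = torch.randn(8, 3, 4, 4) * 5.0

running_mean = torch.zeros(3)
running_var = torch.ones(3)  # the "fixed" variance from the suggestion above

# training=True: the batch mean and batch variance are used for normalization,
# so the value stored in running_var does not affect the output.
out_train = F.batch_norm(x, running_mean.clone(), running_var.clone(), training=True)

# training=False: the running buffers are used, so only here does
# running_var == 1 turn the division into (almost) a no-op.
out_eval = F.batch_norm(x, running_mean.clone(), running_var.clone(), training=False)

train_std = out_train.std().item()
eval_std = out_eval.std().item()
print(train_std, eval_std)
```

In training mode the output is still rescaled to roughly unit variance, while in eval mode the scale of x is preserved.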

Oh I see, sorry. The normalization is done in the CUDA code, so I think the simplest way is to port the OpenAI code: https://github.com/openai/weightnorm
The Lasagne code there is easy to follow.

I’m trying to implement it for my task, loss is decreasing… but not sure if correct

import torch
import torch.nn as nn


class MeanOnlyBatchNorm(nn.Module):
    """Mean-only batch norm for (N, C, H, W) inputs: subtract the per-channel mean, skip the variance."""

    def __init__(self, num_features, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        nn.init.uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, inp):
        gamma = self.weight.view(1, self.num_features, 1, 1)
        beta = self.bias.view(1, self.num_features, 1, 1)

        if self.training:
            # Per-channel mean over the batch and spatial dims (not per sample),
            # so training and eval subtract the same kind of statistic.
            avg = inp.mean(dim=(0, 2, 3))
            # Update the running estimate in place, without tracking gradients.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * avg)
        else:
            avg = self.running_mean

        output = inp - avg.view(1, self.num_features, 1, 1)
        output = output * gamma + beta

        return output
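
One way to check an implementation like this is to compare it with F.batch_norm after undoing the division by the standard deviation. In mean-only batch norm the mean is taken per channel over the batch and spatial dimensions, so the two should agree up to floating-point error. The sketch below assumes training mode and no gamma; mean_only_bn_train is a hypothetical helper written for this check, not part of PyTorch.

```python
import torch
import torch.nn.functional as F


def mean_only_bn_train(x, bias):
    # Hypothetical helper: subtract the per-channel batch mean of an
    # (N, C, H, W) tensor and add a per-channel bias. No variance scaling.
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    return x - mean + bias.view(1, -1, 1, 1)


torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4) * 2.0 + 1.0  # illustrative input
bias = torch.tensor([0.5, -1.0, 2.0])    # illustrative per-channel bias
eps = 1e-5

out = mean_only_bn_train(x, bias)

# Reference: full batch norm without affine parameters, then multiply the
# standard deviation back in so that only the mean subtraction remains.
normalized = F.batch_norm(x, None, None, training=True, eps=eps)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
reference = normalized * torch.sqrt(var + eps) + bias.view(1, -1, 1, 1)

max_diff = (out - reference).abs().max().item()
channel_means = out.mean(dim=(0, 2, 3))
print(max_diff, channel_means)
```

If the per-channel means of the output equal the bias and the difference to the reference is tiny, the mean is being subtracted over the right dimensions.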